Unverified Commit 9560ffa1 authored by Yimei Sun's avatar Yimei Sun Committed by GitHub

Replace copy_with_new_args in A set of ops (#4424)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent d6be21dc
...@@ -31,7 +31,7 @@ op::v0::Add::Add(const Output<Node>& arg0, ...@@ -31,7 +31,7 @@ op::v0::Add::Add(const Output<Node>& arg0,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::v0::Add::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Add::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<op::v0::Add>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v0::Add>(new_args.at(0), new_args.at(1), this->get_autob());
...@@ -82,7 +82,7 @@ bool op::v1::Add::visit_attributes(AttributeVisitor& visitor) ...@@ -82,7 +82,7 @@ bool op::v1::Add::visit_attributes(AttributeVisitor& visitor)
return true; return true;
} }
shared_ptr<Node> op::v1::Add::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Add::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<op::v1::Add>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v1::Add>(new_args.at(0), new_args.at(1), this->get_autob());
......
...@@ -53,7 +53,8 @@ namespace ngraph ...@@ -53,7 +53,8 @@ namespace ngraph
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
...@@ -94,7 +95,8 @@ namespace ngraph ...@@ -94,7 +95,8 @@ namespace ngraph
const AutoBroadcastSpec& auto_broadcast = const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY)); AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
......
...@@ -34,7 +34,7 @@ op::All::All(const Output<Node>& arg, const Output<Node>& reduction_axes) ...@@ -34,7 +34,7 @@ op::All::All(const Output<Node>& arg, const Output<Node>& reduction_axes)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::All::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::All::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<All>(new_args.at(0), new_args.at(1)); return make_shared<All>(new_args.at(0), new_args.at(1));
......
...@@ -43,7 +43,8 @@ namespace ngraph ...@@ -43,7 +43,8 @@ namespace ngraph
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const Output<Node>& arg, const Output<Node>& reduction_axes); All(const Output<Node>& arg, const Output<Node>& reduction_axes);
bool visit_attributes(AttributeVisitor& visitor) override { return true; } bool visit_attributes(AttributeVisitor& visitor) override { return true; }
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The default value for All. /// \return The default value for All.
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
......
...@@ -43,7 +43,7 @@ void op::AllReduce::validate_and_infer_types() ...@@ -43,7 +43,7 @@ void op::AllReduce::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
} }
shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::AllReduce::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<AllReduce>(new_args.at(0), get_reduce_type()); return make_shared<AllReduce>(new_args.at(0), get_reduce_type());
......
...@@ -37,7 +37,8 @@ namespace ngraph ...@@ -37,7 +37,8 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
reduction::Type get_reduce_type() const; reduction::Type get_reduce_type() const;
void set_reduce_type(reduction::Type reduce_type); void set_reduce_type(reduction::Type reduce_type);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
......
...@@ -35,7 +35,7 @@ bool op::v1::LogicalAnd::visit_attributes(AttributeVisitor& visitor) ...@@ -35,7 +35,7 @@ bool op::v1::LogicalAnd::visit_attributes(AttributeVisitor& visitor)
return true; return true;
} }
shared_ptr<Node> op::v1::LogicalAnd::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::LogicalAnd::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<v1::LogicalAnd>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<v1::LogicalAnd>(new_args.at(0), new_args.at(1), this->get_autob());
...@@ -57,7 +57,7 @@ bool op::v0::And::visit_attributes(AttributeVisitor& visitor) ...@@ -57,7 +57,7 @@ bool op::v0::And::visit_attributes(AttributeVisitor& visitor)
return true; return true;
} }
shared_ptr<Node> op::v0::And::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::And::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<v0::And>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<v0::And>(new_args.at(0), new_args.at(1), this->get_autob());
......
...@@ -51,7 +51,8 @@ namespace ngraph ...@@ -51,7 +51,8 @@ namespace ngraph
const AutoBroadcastSpec& auto_broadcast = const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY)); AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
}; };
...@@ -82,7 +83,8 @@ namespace ngraph ...@@ -82,7 +83,8 @@ namespace ngraph
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
}; };
......
...@@ -34,7 +34,7 @@ op::Any::Any(const Output<Node>& arg, const Output<Node>& reduction_axes) ...@@ -34,7 +34,7 @@ op::Any::Any(const Output<Node>& arg, const Output<Node>& reduction_axes)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Any::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Any::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Any>(new_args.at(0), new_args.at(1)); return make_shared<Any>(new_args.at(0), new_args.at(1));
......
...@@ -44,7 +44,7 @@ namespace ngraph ...@@ -44,7 +44,7 @@ namespace ngraph
Any(const Output<Node>& arg, const Output<Node>& reduction_axes); Any(const Output<Node>& arg, const Output<Node>& reduction_axes);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; } bool visit_attributes(AttributeVisitor& visitor) override { return true; }
/// \return The default value for Any. /// \return The default value for Any.
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
......
...@@ -34,9 +34,8 @@ bool op::ArgMax::visit_attributes(AttributeVisitor& visitor) ...@@ -34,9 +34,8 @@ bool op::ArgMax::visit_attributes(AttributeVisitor& visitor)
return true; return true;
} }
shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ArgMax::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args);
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<ArgMax>(new_args.at(0), m_axis, this->get_element_type()); return make_shared<ArgMax>(new_args.at(0), m_axis, this->get_element_type());
} }
......
...@@ -43,7 +43,7 @@ namespace ngraph ...@@ -43,7 +43,7 @@ namespace ngraph
const element::Type& index_element_type); const element::Type& index_element_type);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
}; };
......
...@@ -34,7 +34,7 @@ bool op::ArgMin::visit_attributes(AttributeVisitor& visitor) ...@@ -34,7 +34,7 @@ bool op::ArgMin::visit_attributes(AttributeVisitor& visitor)
return true; return true;
} }
shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ArgMin::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<ArgMin>(new_args.at(0), m_axis, this->get_element_type()); return make_shared<ArgMin>(new_args.at(0), m_axis, this->get_element_type());
......
...@@ -44,7 +44,7 @@ namespace ngraph ...@@ -44,7 +44,7 @@ namespace ngraph
const element::Type& index_element_type); const element::Type& index_element_type);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
}; };
......
...@@ -39,7 +39,7 @@ op::Asin::Asin(const Output<Node>& arg) ...@@ -39,7 +39,7 @@ op::Asin::Asin(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Asin::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Asin::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Asin>(new_args.at(0)); return make_shared<Asin>(new_args.at(0));
......
...@@ -45,7 +45,7 @@ namespace ngraph ...@@ -45,7 +45,7 @@ namespace ngraph
Asin(const Output<Node>& arg); Asin(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; } bool visit_attributes(AttributeVisitor& visitor) override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -38,7 +38,7 @@ op::Atan::Atan(const Output<Node>& arg) ...@@ -38,7 +38,7 @@ op::Atan::Atan(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Atan::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Atan::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Atan>(new_args.at(0)); return make_shared<Atan>(new_args.at(0));
......
...@@ -46,7 +46,7 @@ namespace ngraph ...@@ -46,7 +46,7 @@ namespace ngraph
Atan(const Output<Node>& arg); Atan(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; } bool visit_attributes(AttributeVisitor& visitor) override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -33,7 +33,7 @@ op::v0::Atan2::Atan2(const Output<Node>& y, const Output<Node>& x, const AutoBro ...@@ -33,7 +33,7 @@ op::v0::Atan2::Atan2(const Output<Node>& y, const Output<Node>& x, const AutoBro
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::v0::Atan2::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Atan2::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Atan2>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<Atan2>(new_args.at(0), new_args.at(1), this->get_autob());
......
...@@ -45,7 +45,8 @@ namespace ngraph ...@@ -45,7 +45,8 @@ namespace ngraph
Atan2(const Output<Node>& y, Atan2(const Output<Node>& y,
const Output<Node>& x, const Output<Node>& x,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
protected: protected:
......
...@@ -228,7 +228,7 @@ void op::v0::AvgPool::set_ceil_mode(bool ceil_mode) ...@@ -228,7 +228,7 @@ void op::v0::AvgPool::set_ceil_mode(bool ceil_mode)
m_ceil_mode = ceil_mode; m_ceil_mode = ceil_mode;
} }
shared_ptr<Node> op::v0::AvgPool::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::AvgPool::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<v0::AvgPool>(new_args.at(0), return make_shared<v0::AvgPool>(new_args.at(0),
...@@ -372,7 +372,7 @@ void op::v0::AvgPoolBackprop::set_include_padding_in_avg_computation( ...@@ -372,7 +372,7 @@ void op::v0::AvgPoolBackprop::set_include_padding_in_avg_computation(
m_include_padding_in_avg_computation = include_padding_in_avg_computation; m_include_padding_in_avg_computation = include_padding_in_avg_computation;
} }
shared_ptr<Node> op::v0::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::AvgPoolBackprop::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<v0::AvgPoolBackprop>(m_forward_arg_shape, return make_shared<v0::AvgPoolBackprop>(m_forward_arg_shape,
...@@ -581,7 +581,7 @@ void op::v1::AvgPool::set_rounding_type(op::RoundingType rounding_type) ...@@ -581,7 +581,7 @@ void op::v1::AvgPool::set_rounding_type(op::RoundingType rounding_type)
m_rounding_type = rounding_type; m_rounding_type = rounding_type;
} }
shared_ptr<Node> op::v1::AvgPool::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::AvgPool::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<v1::AvgPool>(new_args.at(0), return make_shared<v1::AvgPool>(new_args.at(0),
...@@ -716,7 +716,7 @@ void op::v1::AvgPoolBackprop::set_exclude_pad(bool exclude_pad) ...@@ -716,7 +716,7 @@ void op::v1::AvgPoolBackprop::set_exclude_pad(bool exclude_pad)
m_exclude_pad = exclude_pad; m_exclude_pad = exclude_pad;
} }
shared_ptr<Node> op::v1::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::AvgPoolBackprop::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<v1::AvgPoolBackprop>(new_args.at(0), return make_shared<v1::AvgPoolBackprop>(new_args.at(0),
......
...@@ -134,7 +134,7 @@ namespace ngraph ...@@ -134,7 +134,7 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override; const OutputVector& deltas) override;
...@@ -190,7 +190,7 @@ namespace ngraph ...@@ -190,7 +190,7 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
const Shape& get_forward_arg_shape() const; const Shape& get_forward_arg_shape() const;
void set_forward_arg_shape(const Shape& forward_arg_shape); void set_forward_arg_shape(const Shape& forward_arg_shape);
...@@ -284,7 +284,7 @@ namespace ngraph ...@@ -284,7 +284,7 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override; const OutputVector& deltas) override;
...@@ -340,7 +340,7 @@ namespace ngraph ...@@ -340,7 +340,7 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
const Shape get_forward_arg_shape() const; const Shape get_forward_arg_shape() const;
const Shape& get_kernel() const; const Shape& get_kernel() const;
......
...@@ -45,13 +45,13 @@ bool check_binary() ...@@ -45,13 +45,13 @@ bool check_binary()
Shape shape{1}; Shape shape{1};
auto arg0 = make_shared<op::Parameter>(element::f32, shape); auto arg0 = make_shared<op::Parameter>(element::f32, shape);
auto arg1 = make_shared<op::Parameter>(element::f32, shape); auto arg1 = make_shared<op::Parameter>(element::f32, shape);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape), OutputVector new_args{make_shared<op::Parameter>(element::f32, shape),
make_shared<op::Parameter>(element::f32, shape)}; make_shared<op::Parameter>(element::f32, shape)};
auto node = make_shared<OP>(arg0, arg1); auto node = make_shared<OP>(arg0, arg1);
auto new_node = node->copy_with_new_args(new_args); auto new_node = node->copy_with_new_inputs(new_args);
return (nullptr != new_node) && (new_args == new_node->get_arguments()); return (nullptr != new_node) && (new_args == new_node->input_values());
} }
TEST(copy, abs) TEST(copy, abs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment