Unverified Commit 9560ffa1 authored by Yimei Sun's avatar Yimei Sun Committed by GitHub

Replace copy_with_new_args in A set of ops (#4424)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent d6be21dc
......@@ -31,7 +31,7 @@ op::v0::Add::Add(const Output<Node>& arg0,
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Add::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Add::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Add>(new_args.at(0), new_args.at(1), this->get_autob());
......@@ -82,7 +82,7 @@ bool op::v1::Add::visit_attributes(AttributeVisitor& visitor)
return true;
}
shared_ptr<Node> op::v1::Add::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::Add::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::Add>(new_args.at(0), new_args.at(1), this->get_autob());
......
......@@ -53,7 +53,8 @@ namespace ngraph
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; }
......@@ -94,7 +95,8 @@ namespace ngraph
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
......
......@@ -34,7 +34,7 @@ op::All::All(const Output<Node>& arg, const Output<Node>& reduction_axes)
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::All::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::All::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<All>(new_args.at(0), new_args.at(1));
......
......@@ -43,7 +43,8 @@ namespace ngraph
/// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const Output<Node>& arg, const Output<Node>& reduction_axes);
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return The default value for All.
virtual std::shared_ptr<Node> get_default_value() const override;
......
......@@ -43,7 +43,7 @@ void op::AllReduce::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::AllReduce::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<AllReduce>(new_args.at(0), get_reduce_type());
......
......@@ -37,7 +37,8 @@ namespace ngraph
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
reduction::Type get_reduce_type() const;
void set_reduce_type(reduction::Type reduce_type);
bool visit_attributes(AttributeVisitor& visitor) override;
......
......@@ -35,7 +35,7 @@ bool op::v1::LogicalAnd::visit_attributes(AttributeVisitor& visitor)
return true;
}
shared_ptr<Node> op::v1::LogicalAnd::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::LogicalAnd::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::LogicalAnd>(new_args.at(0), new_args.at(1), this->get_autob());
......@@ -57,7 +57,7 @@ bool op::v0::And::visit_attributes(AttributeVisitor& visitor)
return true;
}
shared_ptr<Node> op::v0::And::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::And::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v0::And>(new_args.at(0), new_args.at(1), this->get_autob());
......
......@@ -51,7 +51,8 @@ namespace ngraph
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; }
};
......@@ -82,7 +83,8 @@ namespace ngraph
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; }
};
......
......@@ -34,7 +34,7 @@ op::Any::Any(const Output<Node>& arg, const Output<Node>& reduction_axes)
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Any::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::Any::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Any>(new_args.at(0), new_args.at(1));
......
......@@ -44,7 +44,7 @@ namespace ngraph
Any(const Output<Node>& arg, const Output<Node>& reduction_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
/// \return The default value for Any.
virtual std::shared_ptr<Node> get_default_value() const override;
......
......@@ -34,9 +34,8 @@ bool op::ArgMax::visit_attributes(AttributeVisitor& visitor)
return true;
}
shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::ArgMax::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
check_new_args_count(this, new_args);
return make_shared<ArgMax>(new_args.at(0), m_axis, this->get_element_type());
}
......
......@@ -43,7 +43,7 @@ namespace ngraph
const element::Type& index_element_type);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> get_default_value() const override;
};
......
......@@ -34,7 +34,7 @@ bool op::ArgMin::visit_attributes(AttributeVisitor& visitor)
return true;
}
shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::ArgMin::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<ArgMin>(new_args.at(0), m_axis, this->get_element_type());
......
......@@ -44,7 +44,7 @@ namespace ngraph
const element::Type& index_element_type);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> get_default_value() const override;
};
......
......@@ -39,7 +39,7 @@ op::Asin::Asin(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Asin::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::Asin::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Asin>(new_args.at(0));
......
......@@ -45,7 +45,7 @@ namespace ngraph
Asin(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -38,7 +38,7 @@ op::Atan::Atan(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Atan::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::Atan::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Atan>(new_args.at(0));
......
......@@ -46,7 +46,7 @@ namespace ngraph
Atan(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -33,7 +33,7 @@ op::v0::Atan2::Atan2(const Output<Node>& y, const Output<Node>& x, const AutoBro
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Atan2::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Atan2::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Atan2>(new_args.at(0), new_args.at(1), this->get_autob());
......
......@@ -45,7 +45,8 @@ namespace ngraph
Atan2(const Output<Node>& y,
const Output<Node>& x,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
protected:
......
......@@ -228,7 +228,7 @@ void op::v0::AvgPool::set_ceil_mode(bool ceil_mode)
m_ceil_mode = ceil_mode;
}
shared_ptr<Node> op::v0::AvgPool::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::AvgPool::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v0::AvgPool>(new_args.at(0),
......@@ -372,7 +372,7 @@ void op::v0::AvgPoolBackprop::set_include_padding_in_avg_computation(
m_include_padding_in_avg_computation = include_padding_in_avg_computation;
}
shared_ptr<Node> op::v0::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::AvgPoolBackprop::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v0::AvgPoolBackprop>(m_forward_arg_shape,
......@@ -581,7 +581,7 @@ void op::v1::AvgPool::set_rounding_type(op::RoundingType rounding_type)
m_rounding_type = rounding_type;
}
shared_ptr<Node> op::v1::AvgPool::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::AvgPool::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::AvgPool>(new_args.at(0),
......@@ -716,7 +716,7 @@ void op::v1::AvgPoolBackprop::set_exclude_pad(bool exclude_pad)
m_exclude_pad = exclude_pad;
}
shared_ptr<Node> op::v1::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::AvgPoolBackprop::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::AvgPoolBackprop>(new_args.at(0),
......
......@@ -134,7 +134,7 @@ namespace ngraph
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
......@@ -190,7 +190,7 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
const Shape& get_forward_arg_shape() const;
void set_forward_arg_shape(const Shape& forward_arg_shape);
......@@ -284,7 +284,7 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
......@@ -340,7 +340,7 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
const Shape get_forward_arg_shape() const;
const Shape& get_kernel() const;
......
......@@ -45,13 +45,13 @@ bool check_binary()
Shape shape{1};
auto arg0 = make_shared<op::Parameter>(element::f32, shape);
auto arg1 = make_shared<op::Parameter>(element::f32, shape);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape),
make_shared<op::Parameter>(element::f32, shape)};
OutputVector new_args{make_shared<op::Parameter>(element::f32, shape),
make_shared<op::Parameter>(element::f32, shape)};
auto node = make_shared<OP>(arg0, arg1);
auto new_node = node->copy_with_new_args(new_args);
auto new_node = node->copy_with_new_inputs(new_args);
return (nullptr != new_node) && (new_args == new_node->get_arguments());
return (nullptr != new_node) && (new_args == new_node->input_values());
}
TEST(copy, abs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment