Unverified Commit 8ee374fc authored by Yimei Sun's avatar Yimei Sun Committed by GitHub

Replace copy_with_new_args in B op set (#4426)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 9a52ae47
......@@ -74,7 +74,8 @@ void op::BatchNormTraining::validate_and_infer_types()
set_output_type(2, result_et, result_channel_shape);
}
std::shared_ptr<Node> op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
std::shared_ptr<Node>
op::BatchNormTraining::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<BatchNormTraining>(
......@@ -165,7 +166,8 @@ void op::BatchNormInference::validate_and_infer_types()
set_output_type(0, result_et, result_batch_shape);
}
std::shared_ptr<Node> op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
std::shared_ptr<Node>
op::BatchNormInference::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<BatchNormInference>(
......@@ -258,7 +260,7 @@ void op::BatchNormTrainingBackprop::validate_and_infer_types()
}
std::shared_ptr<Node>
op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
op::BatchNormTrainingBackprop::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2),
......
......@@ -77,8 +77,8 @@ namespace ngraph
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......@@ -144,8 +144,8 @@ namespace ngraph
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& /* adjoints */,
......@@ -194,8 +194,8 @@ namespace ngraph
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
private:
static constexpr size_t INPUT_GAMMA = 0;
......
......@@ -129,7 +129,8 @@ void op::v1::BinaryConvolution::validate_and_infer_types()
set_output_type(0, data_batch_et, result_shape);
}
shared_ptr<Node> op::v1::BinaryConvolution::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node>
op::v1::BinaryConvolution::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::BinaryConvolution>(new_args.at(0),
......
......@@ -76,8 +76,8 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
......
......@@ -246,7 +246,7 @@ void op::v1::Broadcast::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), result_shape);
}
shared_ptr<Node> op::v1::Broadcast::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::Broadcast>(
......@@ -339,7 +339,7 @@ void op::v0::Broadcast::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), m_shape);
}
shared_ptr<Node> op::v0::Broadcast::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v0::Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
......@@ -373,7 +373,7 @@ bool op::v0::BroadcastLike::visit_attributes(AttributeVisitor& visitor)
return true;
}
shared_ptr<Node> op::v0::BroadcastLike::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::BroadcastLike::clone_with_new_inputs(const OutputVector& new_args) const
{
if (new_args.size() != 2)
{
......
......@@ -48,8 +48,8 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return A set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
......@@ -93,8 +93,8 @@ namespace ngraph
const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
void infer_shape() override;
const AxisSet& get_initial_broadcast_axes() const
......@@ -155,8 +155,8 @@ namespace ngraph
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return Broadcast Specification.
const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
......
......@@ -48,7 +48,7 @@ void op::BroadcastDistributed::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
shared_ptr<Node> op::BroadcastDistributed::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::BroadcastDistributed::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<BroadcastDistributed>(new_args.at(0), m_root_id);
......
......@@ -36,8 +36,8 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
int64_t get_root_id() const;
void set_root_id(int64_t root_id);
......
......@@ -88,18 +88,18 @@ TEST(copy, broadcast)
{
Shape shape1{1};
auto arg0 = make_shared<op::Parameter>(element::f32, shape1);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape1)};
OutputVector new_args{make_shared<op::Parameter>(element::f32, shape1)};
Shape shape{4, 1, 3};
AxisSet axes{0, 2};
auto node = make_shared<op::Broadcast>(arg0, shape, axes);
auto new_node = node->copy_with_new_args(new_args);
auto new_node = node->copy_with_new_inputs(new_args);
auto node_cast = as_type_ptr<op::Broadcast>(new_node);
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(new_args == new_node->input_values());
ASSERT_TRUE(shape == node_cast->get_broadcast_shape());
ASSERT_TRUE(axes == node_cast->get_broadcast_axes());
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment