Unverified Commit 8ee374fc authored by Yimei Sun's avatar Yimei Sun Committed by GitHub

Replace copy_with_new_args in B op set (#4426)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 9a52ae47
...@@ -74,7 +74,8 @@ void op::BatchNormTraining::validate_and_infer_types() ...@@ -74,7 +74,8 @@ void op::BatchNormTraining::validate_and_infer_types()
set_output_type(2, result_et, result_channel_shape); set_output_type(2, result_et, result_channel_shape);
} }
std::shared_ptr<Node> op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const std::shared_ptr<Node>
op::BatchNormTraining::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<BatchNormTraining>( return std::make_shared<BatchNormTraining>(
...@@ -165,7 +166,8 @@ void op::BatchNormInference::validate_and_infer_types() ...@@ -165,7 +166,8 @@ void op::BatchNormInference::validate_and_infer_types()
set_output_type(0, result_et, result_batch_shape); set_output_type(0, result_et, result_batch_shape);
} }
std::shared_ptr<Node> op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const std::shared_ptr<Node>
op::BatchNormInference::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<BatchNormInference>( return std::make_shared<BatchNormInference>(
...@@ -258,7 +260,7 @@ void op::BatchNormTrainingBackprop::validate_and_infer_types() ...@@ -258,7 +260,7 @@ void op::BatchNormTrainingBackprop::validate_and_infer_types()
} }
std::shared_ptr<Node> std::shared_ptr<Node>
op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const op::BatchNormTrainingBackprop::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2), return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2),
......
...@@ -77,8 +77,8 @@ namespace ngraph ...@@ -77,8 +77,8 @@ namespace ngraph
double get_eps_value() const { return m_epsilon; } double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; } void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node> std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
...@@ -144,8 +144,8 @@ namespace ngraph ...@@ -144,8 +144,8 @@ namespace ngraph
double get_eps_value() const { return m_epsilon; } double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; } void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node> std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& /* adjoints */, virtual void generate_adjoints(autodiff::Adjoints& /* adjoints */,
...@@ -194,8 +194,8 @@ namespace ngraph ...@@ -194,8 +194,8 @@ namespace ngraph
double get_eps_value() const { return m_epsilon; } double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; } void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node> std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
private: private:
static constexpr size_t INPUT_GAMMA = 0; static constexpr size_t INPUT_GAMMA = 0;
......
...@@ -129,7 +129,8 @@ void op::v1::BinaryConvolution::validate_and_infer_types() ...@@ -129,7 +129,8 @@ void op::v1::BinaryConvolution::validate_and_infer_types()
set_output_type(0, data_batch_et, result_shape); set_output_type(0, data_batch_et, result_shape);
} }
shared_ptr<Node> op::v1::BinaryConvolution::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node>
op::v1::BinaryConvolution::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<v1::BinaryConvolution>(new_args.at(0), return make_shared<v1::BinaryConvolution>(new_args.at(0),
......
...@@ -76,8 +76,8 @@ namespace ngraph ...@@ -76,8 +76,8 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
void generate_adjoints(autodiff::Adjoints& adjoints, void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override; const OutputVector& deltas) override;
......
...@@ -246,7 +246,7 @@ void op::v1::Broadcast::validate_and_infer_types() ...@@ -246,7 +246,7 @@ void op::v1::Broadcast::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), result_shape); set_output_type(0, get_input_element_type(0), result_shape);
} }
shared_ptr<Node> op::v1::Broadcast::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<v1::Broadcast>( return make_shared<v1::Broadcast>(
...@@ -339,7 +339,7 @@ void op::v0::Broadcast::validate_and_infer_types() ...@@ -339,7 +339,7 @@ void op::v0::Broadcast::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), m_shape); set_output_type(0, get_input_element_type(0), m_shape);
} }
shared_ptr<Node> op::v0::Broadcast::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<v0::Broadcast>(new_args.at(0), m_shape, m_broadcast_axes); return make_shared<v0::Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
...@@ -373,7 +373,7 @@ bool op::v0::BroadcastLike::visit_attributes(AttributeVisitor& visitor) ...@@ -373,7 +373,7 @@ bool op::v0::BroadcastLike::visit_attributes(AttributeVisitor& visitor)
return true; return true;
} }
shared_ptr<Node> op::v0::BroadcastLike::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::BroadcastLike::clone_with_new_inputs(const OutputVector& new_args) const
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
{ {
......
...@@ -48,8 +48,8 @@ namespace ngraph ...@@ -48,8 +48,8 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return A set containing the indices of the broadcast axes (0-based). /// \return A set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; } const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
...@@ -93,8 +93,8 @@ namespace ngraph ...@@ -93,8 +93,8 @@ namespace ngraph
const Output<Node>& like_arg, const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes); const AxisSet& initial_broadcast_axes);
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
void infer_shape() override; void infer_shape() override;
const AxisSet& get_initial_broadcast_axes() const const AxisSet& get_initial_broadcast_axes() const
...@@ -155,8 +155,8 @@ namespace ngraph ...@@ -155,8 +155,8 @@ namespace ngraph
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return Broadcast Specification. /// \return Broadcast Specification.
const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; } const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
......
...@@ -48,7 +48,7 @@ void op::BroadcastDistributed::validate_and_infer_types() ...@@ -48,7 +48,7 @@ void op::BroadcastDistributed::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
} }
shared_ptr<Node> op::BroadcastDistributed::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::BroadcastDistributed::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<BroadcastDistributed>(new_args.at(0), m_root_id); return make_shared<BroadcastDistributed>(new_args.at(0), m_root_id);
......
...@@ -36,8 +36,8 @@ namespace ngraph ...@@ -36,8 +36,8 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
int64_t get_root_id() const; int64_t get_root_id() const;
void set_root_id(int64_t root_id); void set_root_id(int64_t root_id);
......
...@@ -88,18 +88,18 @@ TEST(copy, broadcast) ...@@ -88,18 +88,18 @@ TEST(copy, broadcast)
{ {
Shape shape1{1}; Shape shape1{1};
auto arg0 = make_shared<op::Parameter>(element::f32, shape1); auto arg0 = make_shared<op::Parameter>(element::f32, shape1);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape1)}; OutputVector new_args{make_shared<op::Parameter>(element::f32, shape1)};
Shape shape{4, 1, 3}; Shape shape{4, 1, 3};
AxisSet axes{0, 2}; AxisSet axes{0, 2};
auto node = make_shared<op::Broadcast>(arg0, shape, axes); auto node = make_shared<op::Broadcast>(arg0, shape, axes);
auto new_node = node->copy_with_new_args(new_args); auto new_node = node->copy_with_new_inputs(new_args);
auto node_cast = as_type_ptr<op::Broadcast>(new_node); auto node_cast = as_type_ptr<op::Broadcast>(new_node);
ASSERT_NE(node_cast, nullptr); ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node); ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments()); ASSERT_TRUE(new_args == new_node->input_values());
ASSERT_TRUE(shape == node_cast->get_broadcast_shape()); ASSERT_TRUE(shape == node_cast->get_broadcast_shape());
ASSERT_TRUE(axes == node_cast->get_broadcast_axes()); ASSERT_TRUE(axes == node_cast->get_broadcast_axes());
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment