Unverified Commit 8017c094 authored by Robert Kimball, committed by GitHub

Update calls to copy_with_new_args to clone_with_new_inputs (#4452)

* more fixes

* Fix test

* Cleanup comment

* style

* Update per review comments

* Change get_arguments to input_values
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent 0af33226
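
For orientation, a minimal sketch of the call-site migration this commit applies throughout the codebase (the variable `some_node` is illustrative, not taken from the diff): the deprecated `copy_with_new_args` takes the `NodeVector` returned by `get_arguments()`, while `clone_with_new_inputs` takes the `OutputVector` returned by `input_values()`.

    // Hypothetical call site; some_node is a std::shared_ptr<ngraph::Node>.
    // Before this commit: deprecated API driven by the argument nodes.
    std::shared_ptr<ngraph::Node> old_copy =
        some_node->copy_with_new_args(some_node->get_arguments());
    // After this commit: clone driven by the node's input values (Output<Node> handles).
    std::shared_ptr<ngraph::Node> new_copy =
        some_node->clone_with_new_inputs(some_node->input_values());
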
@@ -60,10 +60,10 @@ void descriptor::Input::replace_output(Output& new_output)
     if (getenv_bool("NGRAPH_ENABLE_REPLACE_CHECK"))
     {
-        // the result of copy_with_new_args will be thrown away or
+        // the result of clone_with_new_inputs will be thrown away or
         // an exception will be thrown by `m_node`'s class c-tor
         // if a new input violates one of the type checks in the c-tor.
-        (this->m_node->copy_with_new_args(this->m_node->get_arguments()));
+        m_node->clone_with_new_inputs(m_node->input_values());
     }
 }
@@ -215,7 +215,7 @@ namespace ngraph
    ///
    /// To avoid the cycle, a valid way to perform the above desired insertion would be,
    ///
-   ///     auto new_N = N->copy_with_new_args(N->get_arguments());
+   ///     auto new_N = N->clone_with_new_inputs(N->input_values());
    ///     shared_ptr<Node> M = make_shared<SomeUnaryOp>(new_N);
    ///     replace_node(N, M);
    NGRAPH_API
@@ -428,11 +428,11 @@ namespace ngraph
         virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const
             NGRAPH_DEPRECATED("use copy_with_new_inputs instead");
+    public:
         // TODO: When all copy_with_new_args have been replaced with copy_with_new_inputs, make
         // this pure and remove copy_with_new_args
         virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const;
-    public:
         std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
         std::shared_ptr<Node> copy_with_new_inputs(
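
The new virtual above is what op classes override. A minimal sketch of such an override, assuming a hypothetical single-input op `MyRelu` (not part of this diff); the arity check mirrors what existing `copy_with_new_args` implementations typically do via `check_new_args_count`:

    // Sketch only: MyRelu is a hypothetical op deriving from ngraph::op::Op,
    // declared elsewhere with a one-input constructor.
    std::shared_ptr<ngraph::Node>
        MyRelu::clone_with_new_inputs(const ngraph::OutputVector& new_args) const
    {
        // Verify the replacement inputs match this op's expected arity.
        ngraph::check_new_args_count(this, new_args);
        // Re-run the constructor so type/shape validation applies to the new input.
        return std::make_shared<MyRelu>(new_args.at(0));
    }
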
@@ -165,14 +165,15 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
     {
         auto csw = work_queue.front();
         work_queue.pop_front();
-        auto n = csw.input.get_source_output().get_node_shared_ptr();
-        auto materialize = [csw, n]() {
-            auto new_reshape = csw.reshape->copy_with_new_args({n});
+        auto n_output = csw.input.get_source_output();
+        auto n = n_output.get_node_shared_ptr();
+        auto materialize = [csw, n_output]() {
+            auto n = n_output.get_node_shared_ptr();
+            auto new_reshape = csw.reshape->clone_with_new_inputs({n});
             new_reshape->merge_provenance_tags_from(n);
             NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
             csw.input.replace_source_output(new_reshape->output(0));
-        };
-        // Only swim past nodes which have a single user
+        }; // Only swim past nodes which have a single user
         if (n->get_users().size() > 1)
         {
             materialize();
@@ -115,7 +115,7 @@ bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
         if (auto concat = as_type_ptr<op::Concat>(n))
         {
-            NodeVector non_zero_dim_args;
+            OutputVector non_zero_dim_args;
             for (auto arg : concat->get_arguments())
             {
                 if (!has_zero_dim(arg))
@@ -126,7 +126,7 @@ bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
             if (non_zero_dim_args.size() < concat->get_input_size())
             {
-                auto new_concat = concat->copy_with_new_args(non_zero_dim_args);
+                auto new_concat = concat->clone_with_new_inputs(non_zero_dim_args);
                 NGRAPH_DEBUG << " Replacing " << n->get_name() << " with "
                              << new_concat->get_name();
                 replace_node(concat, new_concat);
@@ -1015,7 +1015,7 @@ TEST(constant, shared_data)
 {
     Shape shape{100, 200};
     auto c1 = make_shared<op::Constant>(element::f16, shape, vector<float16>{123});
-    auto c2 = static_pointer_cast<op::Constant>(c1->copy_with_new_args({}));
+    auto c2 = static_pointer_cast<op::Constant>(c1->clone_with_new_inputs({}));
     const float* p1 = c1->get_data_ptr<float>();
     const float* p2 = c2->get_data_ptr<float>();
     EXPECT_EQ(p1, p2);
@@ -114,16 +114,16 @@ TEST(copy, concat)
     Shape shape{1};
     auto arg0 = make_shared<op::Parameter>(element::f32, shape);
     auto arg1 = make_shared<op::Parameter>(element::f32, shape);
-    NodeVector new_args{make_shared<op::Parameter>(element::f32, shape),
-                        make_shared<op::Parameter>(element::f32, shape)};
+    OutputVector new_args{make_shared<op::Parameter>(element::f32, shape),
+                          make_shared<op::Parameter>(element::f32, shape)};
     size_t axis = 0;
     auto node = make_shared<op::Concat>(NodeVector{arg0, arg1}, axis);
-    auto new_node = node->copy_with_new_args(new_args);
+    auto new_node = node->clone_with_new_inputs(new_args);
     auto node_cast = as_type_ptr<op::Concat>(new_node);
     ASSERT_NE(node_cast, nullptr);
     ASSERT_TRUE(nullptr != new_node);
-    ASSERT_TRUE(new_args == new_node->get_arguments());
+    ASSERT_TRUE(new_args == new_node->input_values());
     ASSERT_TRUE(node_cast->get_concatenation_axis() == axis);
 }
@@ -133,7 +133,7 @@ TEST(copy, constant)
     vector<float> c{2.4f};
     auto& et = element::f32;
     auto node = op::Constant::create(et, shape, c);
-    auto new_node = node->copy_with_new_args(NodeVector{});
+    auto new_node = node->clone_with_new_inputs(OutputVector{});
     auto node_cast = as_type_ptr<op::Constant>(new_node);
     ASSERT_NE(node_cast, nullptr);
     ASSERT_TRUE(nullptr != new_node);
@@ -148,15 +148,15 @@ TEST(copy, convert)
     Shape shape;
     auto& et = element::f64;
     auto arg0 = make_shared<op::Parameter>(element::f32, shape);
-    NodeVector new_args{make_shared<op::Parameter>(element::f32, shape)};
+    OutputVector new_args{make_shared<op::Parameter>(element::f32, shape)};
     auto node = make_shared<op::Convert>(arg0, et);
-    auto new_node = node->copy_with_new_args(new_args);
+    auto new_node = node->clone_with_new_inputs(new_args);
     auto node_cast = as_type_ptr<op::Convert>(new_node);
     ASSERT_NE(node_cast, nullptr);
     ASSERT_TRUE(nullptr != new_node);
-    ASSERT_TRUE(new_args == new_node->get_arguments());
+    ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
     ASSERT_TRUE(et == node_cast->get_convert_element_type());
 }
@@ -249,7 +249,7 @@ TEST(copy, parameter)
 {
     Shape shape{1};
     auto node = make_shared<op::Parameter>(element::f32, shape);
-    auto new_node = node->copy_with_new_args({});
+    auto new_node = node->clone_with_new_inputs({});
     auto node_cast = as_type_ptr<op::Parameter>(new_node);
     ASSERT_NE(node_cast, nullptr);
@@ -270,15 +270,15 @@ TEST(copy, reshape)
     Shape shape_out{6, 4};
     auto arg0 = make_shared<op::Parameter>(element::f32, shape_in);
-    NodeVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
+    OutputVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
     auto node = make_shared<op::Reshape>(arg0, axes, shape_out);
-    auto new_node = node->copy_with_new_args(new_args);
+    auto new_node = node->clone_with_new_inputs(new_args);
     auto node_cast = as_type_ptr<op::Reshape>(new_node);
     ASSERT_NE(node_cast, nullptr);
     ASSERT_TRUE(nullptr != new_node);
-    ASSERT_TRUE(new_args == new_node->get_arguments());
+    ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
     ASSERT_TRUE(axes == node_cast->get_input_order());
     ASSERT_TRUE(shape_out == node_cast->get_output_shape(0));
 }
@@ -289,17 +289,17 @@ TEST(copy, select)
     auto arg0 = make_shared<op::Parameter>(element::boolean, shape);
     auto arg1 = make_shared<op::Parameter>(element::f32, shape);
     auto arg2 = make_shared<op::Parameter>(element::f32, shape);
-    NodeVector new_args{make_shared<op::Parameter>(element::boolean, shape),
-                        make_shared<op::Parameter>(element::f32, shape),
-                        make_shared<op::Parameter>(element::f32, shape)};
+    OutputVector new_args{make_shared<op::Parameter>(element::boolean, shape),
+                          make_shared<op::Parameter>(element::f32, shape),
+                          make_shared<op::Parameter>(element::f32, shape)};
     auto node = make_shared<op::Select>(arg0, arg1, arg2);
-    auto new_node = node->copy_with_new_args(new_args);
+    auto new_node = node->clone_with_new_inputs(new_args);
     auto node_cast = as_type_ptr<op::Select>(new_node);
     ASSERT_NE(node_cast, nullptr);
     ASSERT_TRUE(nullptr != new_node);
-    ASSERT_TRUE(new_args == new_node->get_arguments());
+    ASSERT_TRUE(new_args == new_node->input_values());
 }

 TEST(copy, sign)
@@ -325,15 +325,15 @@ TEST(copy, slice)
     Strides strides{1, 1, 1};
     auto arg0 = make_shared<op::Parameter>(element::f32, shape_in);
-    NodeVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
+    OutputVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
     auto node = make_shared<op::Slice>(arg0, lower, upper, strides);
-    auto new_node = node->copy_with_new_args(new_args);
+    auto new_node = node->clone_with_new_inputs(new_args);
     auto node_cast = as_type_ptr<op::Slice>(new_node);
     ASSERT_NE(node_cast, nullptr);
     ASSERT_TRUE(nullptr != new_node);
-    ASSERT_TRUE(new_args == new_node->get_arguments());
+    ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
     ASSERT_TRUE(lower == node_cast->get_lower_bounds());
     ASSERT_TRUE(upper == node_cast->get_upper_bounds());
     ASSERT_TRUE(strides == node_cast->get_strides());
@@ -351,13 +351,13 @@ TEST(copy, sum)
     auto arg0 = make_shared<op::Parameter>(element::f32, shape);
     auto node = make_shared<op::Sum>(arg0, axes);
-    NodeVector new_args{make_shared<op::Parameter>(element::f32, shape), node->get_argument(1)};
-    auto new_node = node->copy_with_new_args(new_args);
+    OutputVector new_args{make_shared<op::Parameter>(element::f32, shape), node->get_argument(1)};
+    auto new_node = node->clone_with_new_inputs(new_args);
     auto node_cast = as_type_ptr<op::Sum>(new_node);
     ASSERT_NE(node_cast, nullptr);
     ASSERT_TRUE(nullptr != new_node);
-    ASSERT_TRUE(new_args == new_node->get_arguments());
+    ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
     ASSERT_TRUE(axes == node_cast->get_reduction_axes());
 }