Unverified Commit 8017c094 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Update calls to copy_with_new_args to clone_with_new_inputs (#4452)

* more fixes

* Fix test

* Cleanup comment

* style

* Update per review comments

* Change get_arguments to input_values
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 0af33226
......@@ -60,10 +60,10 @@ void descriptor::Input::replace_output(Output& new_output)
if (getenv_bool("NGRAPH_ENABLE_REPLACE_CHECK"))
{
// the result of copy_with_new_args will be thrown away or
// the result of clone_with_new_inputs will be thrown away or
// an exception will be thrown by `m_node`'s class c-tor
// if a new input violates one of the type checks in the c-tor.
(this->m_node->copy_with_new_args(this->m_node->get_arguments()));
m_node->clone_with_new_inputs(m_node->input_values());
}
}
......
......@@ -215,7 +215,7 @@ namespace ngraph
///
/// To avoid the cycle, a valid way to perform the above desired insertion would be,
///
/// auto new_N = N->copy_with_new_args(N->get_arguments());
/// auto new_N = N->clone_with_new_inputs(N->input_values());
/// shared_ptr<Node> M = make_shared<SomeUnaryOp>(new_N);
/// replace_node(N, M);
NGRAPH_API
......
......@@ -428,11 +428,11 @@ namespace ngraph
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const
NGRAPH_DEPRECATED("use copy_with_new_inputs instead");
public:
// TODO: When all copy_with_new_args have been replaced with copy_with_new_inputs, make
// this pure and remove copy_with_new_args
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const;
public:
std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
std::shared_ptr<Node> copy_with_new_inputs(
......
......@@ -165,14 +165,15 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
{
auto csw = work_queue.front();
work_queue.pop_front();
auto n = csw.input.get_source_output().get_node_shared_ptr();
auto materialize = [csw, n]() {
auto new_reshape = csw.reshape->copy_with_new_args({n});
auto n_output = csw.input.get_source_output();
auto n = n_output.get_node_shared_ptr();
auto materialize = [csw, n_output]() {
auto n = n_output.get_node_shared_ptr();
auto new_reshape = csw.reshape->clone_with_new_inputs({n});
new_reshape->merge_provenance_tags_from(n);
NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
csw.input.replace_source_output(new_reshape->output(0));
};
// Only swim past nodes which have a single user
}; // Only swim past nodes which have a single user
if (n->get_users().size() > 1)
{
materialize();
......
......@@ -115,7 +115,7 @@ bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
if (auto concat = as_type_ptr<op::Concat>(n))
{
NodeVector non_zero_dim_args;
OutputVector non_zero_dim_args;
for (auto arg : concat->get_arguments())
{
if (!has_zero_dim(arg))
......@@ -126,7 +126,7 @@ bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
if (non_zero_dim_args.size() < concat->get_input_size())
{
auto new_concat = concat->copy_with_new_args(non_zero_dim_args);
auto new_concat = concat->clone_with_new_inputs(non_zero_dim_args);
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with "
<< new_concat->get_name();
replace_node(concat, new_concat);
......
......@@ -1015,7 +1015,7 @@ TEST(constant, shared_data)
{
Shape shape{100, 200};
auto c1 = make_shared<op::Constant>(element::f16, shape, vector<float16>{123});
auto c2 = static_pointer_cast<op::Constant>(c1->copy_with_new_args({}));
auto c2 = static_pointer_cast<op::Constant>(c1->clone_with_new_inputs({}));
const float* p1 = c1->get_data_ptr<float>();
const float* p2 = c2->get_data_ptr<float>();
EXPECT_EQ(p1, p2);
......
......@@ -114,16 +114,16 @@ TEST(copy, concat)
Shape shape{1};
auto arg0 = make_shared<op::Parameter>(element::f32, shape);
auto arg1 = make_shared<op::Parameter>(element::f32, shape);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape),
OutputVector new_args{make_shared<op::Parameter>(element::f32, shape),
make_shared<op::Parameter>(element::f32, shape)};
size_t axis = 0;
auto node = make_shared<op::Concat>(NodeVector{arg0, arg1}, axis);
auto new_node = node->copy_with_new_args(new_args);
auto new_node = node->clone_with_new_inputs(new_args);
auto node_cast = as_type_ptr<op::Concat>(new_node);
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(new_args == new_node->input_values());
ASSERT_TRUE(node_cast->get_concatenation_axis() == axis);
}
......@@ -133,7 +133,7 @@ TEST(copy, constant)
vector<float> c{2.4f};
auto& et = element::f32;
auto node = op::Constant::create(et, shape, c);
auto new_node = node->copy_with_new_args(NodeVector{});
auto new_node = node->clone_with_new_inputs(OutputVector{});
auto node_cast = as_type_ptr<op::Constant>(new_node);
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
......@@ -148,15 +148,15 @@ TEST(copy, convert)
Shape shape;
auto& et = element::f64;
auto arg0 = make_shared<op::Parameter>(element::f32, shape);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape)};
OutputVector new_args{make_shared<op::Parameter>(element::f32, shape)};
auto node = make_shared<op::Convert>(arg0, et);
auto new_node = node->copy_with_new_args(new_args);
auto new_node = node->clone_with_new_inputs(new_args);
auto node_cast = as_type_ptr<op::Convert>(new_node);
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
ASSERT_TRUE(et == node_cast->get_convert_element_type());
}
......@@ -249,7 +249,7 @@ TEST(copy, parameter)
{
Shape shape{1};
auto node = make_shared<op::Parameter>(element::f32, shape);
auto new_node = node->copy_with_new_args({});
auto new_node = node->clone_with_new_inputs({});
auto node_cast = as_type_ptr<op::Parameter>(new_node);
ASSERT_NE(node_cast, nullptr);
......@@ -270,15 +270,15 @@ TEST(copy, reshape)
Shape shape_out{6, 4};
auto arg0 = make_shared<op::Parameter>(element::f32, shape_in);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
OutputVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
auto node = make_shared<op::Reshape>(arg0, axes, shape_out);
auto new_node = node->copy_with_new_args(new_args);
auto new_node = node->clone_with_new_inputs(new_args);
auto node_cast = as_type_ptr<op::Reshape>(new_node);
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
ASSERT_TRUE(axes == node_cast->get_input_order());
ASSERT_TRUE(shape_out == node_cast->get_output_shape(0));
}
......@@ -289,17 +289,17 @@ TEST(copy, select)
auto arg0 = make_shared<op::Parameter>(element::boolean, shape);
auto arg1 = make_shared<op::Parameter>(element::f32, shape);
auto arg2 = make_shared<op::Parameter>(element::f32, shape);
NodeVector new_args{make_shared<op::Parameter>(element::boolean, shape),
OutputVector new_args{make_shared<op::Parameter>(element::boolean, shape),
make_shared<op::Parameter>(element::f32, shape),
make_shared<op::Parameter>(element::f32, shape)};
auto node = make_shared<op::Select>(arg0, arg1, arg2);
auto new_node = node->copy_with_new_args(new_args);
auto new_node = node->clone_with_new_inputs(new_args);
auto node_cast = as_type_ptr<op::Select>(new_node);
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(new_args == new_node->input_values());
}
TEST(copy, sign)
......@@ -325,15 +325,15 @@ TEST(copy, slice)
Strides strides{1, 1, 1};
auto arg0 = make_shared<op::Parameter>(element::f32, shape_in);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
OutputVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
auto node = make_shared<op::Slice>(arg0, lower, upper, strides);
auto new_node = node->copy_with_new_args(new_args);
auto new_node = node->clone_with_new_inputs(new_args);
auto node_cast = as_type_ptr<op::Slice>(new_node);
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
ASSERT_TRUE(lower == node_cast->get_lower_bounds());
ASSERT_TRUE(upper == node_cast->get_upper_bounds());
ASSERT_TRUE(strides == node_cast->get_strides());
......@@ -351,13 +351,13 @@ TEST(copy, sum)
auto arg0 = make_shared<op::Parameter>(element::f32, shape);
auto node = make_shared<op::Sum>(arg0, axes);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape), node->get_argument(1)};
auto new_node = node->copy_with_new_args(new_args);
OutputVector new_args{make_shared<op::Parameter>(element::f32, shape), node->get_argument(1)};
auto new_node = node->clone_with_new_inputs(new_args);
auto node_cast = as_type_ptr<op::Sum>(new_node);
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
ASSERT_TRUE(axes == node_cast->get_reduction_axes());
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment