Unverified Commit a573dce2 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Start finishing cloning with outputs (#4330)

* Start finishing cloning with outputs

* Update src/ngraph/node.cpp
Co-Authored-By: 's avatarRobert Kimball <robert.kimball@intel.com>

* Review comments, compilation
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 686574e8
......@@ -104,48 +104,36 @@ std::shared_ptr<Node>
Node::copy_with_new_inputs(const OutputVector& inputs,
const std::vector<std::shared_ptr<Node>>& control_dependencies) const
{
shared_ptr<Node> clone;
if (is_type<op::GetOutputElement>(this))
shared_ptr<Node> clone = clone_with_new_inputs(inputs);
for (auto& cdep : control_dependencies)
{
auto& value = inputs.at(0);
clone = make_shared<op::GetOutputElement>(value.get_node_shared_ptr(), value.get_index());
clone->add_control_dependency(cdep);
}
else
return clone;
}
std::shared_ptr<Node> Node::copy_with_new_args(const NodeVector& args) const
{
NODE_VALIDATION_CHECK(
this, false, "Internal error: copy_with_new_args not replaced by clone_with_new_inputs");
return nullptr;
}
std::shared_ptr<Node> Node::clone_with_new_inputs(const OutputVector& inputs) const
{
NodeVector args;
for (const Output<Node>& input : inputs)
{
NodeVector args;
for (const Output<Node>& input : inputs)
{
args.push_back(get_output_element(input, false));
}
for (int i = 0; i < inputs.size(); ++i)
{
auto in_val = inputs.at(i);
if (is_type<op::GetOutputElement>(in_val.get_node()))
{
in_val = as_type_ptr<op::GetOutputElement>(in_val.get_node_shared_ptr())
->get_as_output();
}
auto in_index = in_val.get_index();
auto arg = args.at(i);
size_t out_index = 0;
if (is_type<op::GetOutputElement>(arg))
{
out_index = as_type_ptr<op::GetOutputElement>(arg)->get_n();
}
if (in_index != out_index)
{
cerr << "Mismatch in: " << in_index << " arg: " << out_index << endl;
cerr << "ARG: " << *arg << endl;
cerr << "IN: " << *inputs.at(i).get_node() << endl;
cerr << "INV: " << *in_val.get_node() << endl;
cerr << "In node " << *this << endl;
}
}
clone = copy_with_new_args(args);
args.push_back(get_output_element(input, false));
}
for (auto& cdep : control_dependencies)
std::shared_ptr<Node> clone = copy_with_new_args(args);
// Remove the inserted GOEs
for (size_t i = 0; i < inputs.size(); ++i)
{
clone->add_control_dependency(cdep);
if (clone->input_value(i) != inputs.at(i))
{
clone->set_argument(i, inputs.at(i));
}
}
return clone;
}
......
......@@ -405,8 +405,13 @@ namespace ngraph
std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const;
protected:
// Will be replaced with an OutputVector version
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const = 0;
// Will be replaced with clone_with_new_inputs
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const
NGRAPH_DEPRECATED("use copy_with_new_inputs instead");
// TODO: When all copy_with_new_args have been replaced with copy_with_new_inputs, make
// this pure and remove copy_with_new_args
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const;
public:
std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
......
......@@ -29,7 +29,7 @@ op::Abs::Abs(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Abs::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::Abs::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Abs>(new_args.at(0));
......
......@@ -45,7 +45,8 @@ namespace ngraph
///
Abs(const Output<Node>& arg);
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -38,7 +38,7 @@ op::Acos::Acos(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Acos::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::Acos::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Acos>(new_args.at(0));
......
......@@ -44,7 +44,8 @@ namespace ngraph
///
Acos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -43,10 +43,10 @@ void op::GetOutputElement::validate_and_infer_types()
set_output_type(0, input(0).get_element_type(), input(0).get_partial_shape());
}
shared_ptr<Node> op::GetOutputElement::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::GetOutputElement::clone_with_new_inputs(const OutputVector& inputs) const
{
check_new_args_count(this, new_args);
return make_shared<GetOutputElement>(new_args.at(0), m_n);
auto& value = inputs.at(0);
return make_shared<op::GetOutputElement>(value.get_node_shared_ptr(), value.get_index());
}
Output<Node> op::GetOutputElement::get_as_output() const
......
......@@ -42,8 +42,8 @@ namespace ngraph
/// Return the equilent Output<Node>
Output<Node> get_as_output() const;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& inputs) const override;
void validate_and_infer_types() override;
/// \return The index of the tuple element to get.
......
......@@ -31,12 +31,12 @@ bool check_unary()
{
Shape shape{1};
auto arg0 = make_shared<op::Parameter>(element::f32, shape);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape)};
OutputVector new_args{make_shared<op::Parameter>(element::f32, shape)};
auto node = make_shared<OP>(arg0);
auto new_node = node->copy_with_new_args(new_args);
auto new_node = node->copy_with_new_inputs(new_args);
return (nullptr != new_node) && (new_args == new_node->get_arguments());
return (nullptr != new_node) && (new_args == new_node->input_values());
}
template <typename OP>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment