Unverified Commit a573dce2 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Start finishing cloning with outputs (#4330)

* Start finishing cloning with outputs

* Update src/ngraph/node.cpp
Co-Authored-By: 's avatarRobert Kimball <robert.kimball@intel.com>

* Review comments, compilation
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 686574e8
...@@ -104,48 +104,36 @@ std::shared_ptr<Node> ...@@ -104,48 +104,36 @@ std::shared_ptr<Node>
Node::copy_with_new_inputs(const OutputVector& inputs, Node::copy_with_new_inputs(const OutputVector& inputs,
const std::vector<std::shared_ptr<Node>>& control_dependencies) const const std::vector<std::shared_ptr<Node>>& control_dependencies) const
{ {
shared_ptr<Node> clone; shared_ptr<Node> clone = clone_with_new_inputs(inputs);
if (is_type<op::GetOutputElement>(this)) for (auto& cdep : control_dependencies)
{ {
auto& value = inputs.at(0); clone->add_control_dependency(cdep);
clone = make_shared<op::GetOutputElement>(value.get_node_shared_ptr(), value.get_index());
} }
else return clone;
}
std::shared_ptr<Node> Node::copy_with_new_args(const NodeVector& args) const
{
NODE_VALIDATION_CHECK(
this, false, "Internal error: copy_with_new_args not replaced by clone_with_new_inputs");
return nullptr;
}
std::shared_ptr<Node> Node::clone_with_new_inputs(const OutputVector& inputs) const
{
NodeVector args;
for (const Output<Node>& input : inputs)
{ {
NodeVector args; args.push_back(get_output_element(input, false));
for (const Output<Node>& input : inputs)
{
args.push_back(get_output_element(input, false));
}
for (int i = 0; i < inputs.size(); ++i)
{
auto in_val = inputs.at(i);
if (is_type<op::GetOutputElement>(in_val.get_node()))
{
in_val = as_type_ptr<op::GetOutputElement>(in_val.get_node_shared_ptr())
->get_as_output();
}
auto in_index = in_val.get_index();
auto arg = args.at(i);
size_t out_index = 0;
if (is_type<op::GetOutputElement>(arg))
{
out_index = as_type_ptr<op::GetOutputElement>(arg)->get_n();
}
if (in_index != out_index)
{
cerr << "Mismatch in: " << in_index << " arg: " << out_index << endl;
cerr << "ARG: " << *arg << endl;
cerr << "IN: " << *inputs.at(i).get_node() << endl;
cerr << "INV: " << *in_val.get_node() << endl;
cerr << "In node " << *this << endl;
}
}
clone = copy_with_new_args(args);
} }
for (auto& cdep : control_dependencies) std::shared_ptr<Node> clone = copy_with_new_args(args);
// Remove the inserted GOEs
for (size_t i = 0; i < inputs.size(); ++i)
{ {
clone->add_control_dependency(cdep); if (clone->input_value(i) != inputs.at(i))
{
clone->set_argument(i, inputs.at(i));
}
} }
return clone; return clone;
} }
......
...@@ -405,8 +405,13 @@ namespace ngraph ...@@ -405,8 +405,13 @@ namespace ngraph
std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const; std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const;
protected: protected:
// Will be replaced with an OutputVector version // Will be replaced with clone_with_new_inputs
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const = 0; virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const
NGRAPH_DEPRECATED("use copy_with_new_inputs instead");
// TODO: When all copy_with_new_args have been replaced with copy_with_new_inputs, make
// this pure and remove copy_with_new_args
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const;
public: public:
std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const; std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
......
...@@ -29,7 +29,7 @@ op::Abs::Abs(const Output<Node>& arg) ...@@ -29,7 +29,7 @@ op::Abs::Abs(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Abs::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Abs::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Abs>(new_args.at(0)); return make_shared<Abs>(new_args.at(0));
......
...@@ -45,7 +45,8 @@ namespace ngraph ...@@ -45,7 +45,8 @@ namespace ngraph
/// ///
Abs(const Output<Node>& arg); Abs(const Output<Node>& arg);
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -38,7 +38,7 @@ op::Acos::Acos(const Output<Node>& arg) ...@@ -38,7 +38,7 @@ op::Acos::Acos(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Acos::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Acos::clone_with_new_inputs(const OutputVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Acos>(new_args.at(0)); return make_shared<Acos>(new_args.at(0));
......
...@@ -44,7 +44,8 @@ namespace ngraph ...@@ -44,7 +44,8 @@ namespace ngraph
/// ///
Acos(const Output<Node>& arg); Acos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override { return true; } bool visit_attributes(AttributeVisitor& visitor) override { return true; }
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -43,10 +43,10 @@ void op::GetOutputElement::validate_and_infer_types() ...@@ -43,10 +43,10 @@ void op::GetOutputElement::validate_and_infer_types()
set_output_type(0, input(0).get_element_type(), input(0).get_partial_shape()); set_output_type(0, input(0).get_element_type(), input(0).get_partial_shape());
} }
shared_ptr<Node> op::GetOutputElement::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::GetOutputElement::clone_with_new_inputs(const OutputVector& inputs) const
{ {
check_new_args_count(this, new_args); auto& value = inputs.at(0);
return make_shared<GetOutputElement>(new_args.at(0), m_n); return make_shared<op::GetOutputElement>(value.get_node_shared_ptr(), value.get_index());
} }
Output<Node> op::GetOutputElement::get_as_output() const Output<Node> op::GetOutputElement::get_as_output() const
......
...@@ -42,8 +42,8 @@ namespace ngraph ...@@ -42,8 +42,8 @@ namespace ngraph
/// Return the equilent Output<Node> /// Return the equilent Output<Node>
Output<Node> get_as_output() const; Output<Node> get_as_output() const;
virtual std::shared_ptr<Node> std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; clone_with_new_inputs(const OutputVector& inputs) const override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
/// \return The index of the tuple element to get. /// \return The index of the tuple element to get.
......
...@@ -31,12 +31,12 @@ bool check_unary() ...@@ -31,12 +31,12 @@ bool check_unary()
{ {
Shape shape{1}; Shape shape{1};
auto arg0 = make_shared<op::Parameter>(element::f32, shape); auto arg0 = make_shared<op::Parameter>(element::f32, shape);
NodeVector new_args{make_shared<op::Parameter>(element::f32, shape)}; OutputVector new_args{make_shared<op::Parameter>(element::f32, shape)};
auto node = make_shared<OP>(arg0); auto node = make_shared<OP>(arg0);
auto new_node = node->copy_with_new_args(new_args); auto new_node = node->copy_with_new_inputs(new_args);
return (nullptr != new_node) && (new_args == new_node->get_arguments()); return (nullptr != new_node) && (new_args == new_node->input_values());
} }
template <typename OP> template <typename OP>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment