Unverified Commit 0c523507 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Some GetOutputElement changes to help with Output<Node> (#3306)

* Some GetOutputElement changes to help with Output<Node>

* Review comments
parent 1d5d2024
......@@ -293,9 +293,9 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
return std::make_shared<ngraph::Function>(cloned_results, cloned_params);
}
bool ngraph::is_equal_to_const_value(std::string const_value, std::shared_ptr<Node> reduce_constant)
bool ngraph::is_equal_to_const_value(std::string const_value, const Output<Node>& reduce_constant)
{
if (auto rc = dynamic_pointer_cast<ngraph::op::Constant>(reduce_constant))
if (auto rc = dynamic_pointer_cast<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
{
auto cshape = rc->get_shape();
size_t n = shape_size(cshape);
......@@ -454,7 +454,7 @@ std::shared_ptr<Node> ngraph::make_constant_from_string(std::string val,
return std::make_shared<op::Constant>(element_type, shape, cvals);
}
bool ngraph::is_zero(std::shared_ptr<Node> reduce_constant)
bool ngraph::is_zero(const Output<Node>& reduce_constant)
{
auto result_bool = is_equal_to_const_value("0", reduce_constant);
return result_bool;
......
......@@ -349,7 +349,7 @@ namespace ngraph
// Check if all paths from X to a result go through Y
bool is_post_dominated(Node* X, Node* Y);
bool is_equal_to_const_value(std::string const_value, std::shared_ptr<Node> reduce_constant);
bool is_equal_to_const_value(std::string const_value, const Output<Node>& reduce_constant);
// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
......@@ -383,7 +383,7 @@ namespace ngraph
const element::Type& element_type,
const Shape& shape);
bool is_zero(std::shared_ptr<Node> reduce_constant);
bool is_zero(const Output<Node>& reduce_constant);
NodeVector get_subgraph_outputs(const NodeVector& nodes,
const NodeVector& exclusions,
......
......@@ -86,6 +86,18 @@ std::shared_ptr<Node> Node::copy_with_new_inputs(const OutputVector& inputs) con
return copy_with_new_inputs(inputs, get_control_dependencies());
}
std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i)
{
for (auto in : output(i).get_target_inputs())
{
if (in.get_node()->description() == op::GetOutputElement::type_name)
{
return in.get_node()->shared_from_this();
}
}
return get_output_element(output(i), true);
}
std::shared_ptr<Node>
Node::copy_with_new_inputs(const OutputVector& inputs,
const std::vector<std::shared_ptr<Node>>& control_dependencies) const
......
......@@ -279,6 +279,8 @@ namespace ngraph
/// Returns the partial shape for output i
const PartialShape& get_output_partial_shape(size_t i) const;
std::shared_ptr<Node> get_output_as_single_output_node(size_t i);
/// Checks that there is exactly one output and returns its shape
// TODO: deprecate in favor of node->output(0).get_shape() with a suitable check in the
// calling code, or updates to the calling code if it is making an invalid assumption of
......@@ -554,6 +556,13 @@ namespace ngraph
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; }
/// \return A useable shared pointer to this output. If index 0, the node,
/// otherwise find or create a GOE.
std::shared_ptr<Node> as_single_output_node() const NGRAPH_DEPRECATED("Transitional.")
{
return m_node->get_output_as_single_output_node(m_index);
}
/// \return The index of the output referred to by this output handle.
size_t get_index() const { return m_index; }
/// \return A reference to the tensor descriptor for this output.
......
......@@ -71,20 +71,7 @@ NodeVector op::get_output_elements(const shared_ptr<Node>& mon)
NodeVector goes(mon->get_output_size());
for (auto o : mon->outputs())
{
shared_ptr<Node> goe;
for (auto in : o.get_target_inputs())
{
if (in.get_node()->description() == op::GetOutputElement::type_name)
{
goe = in.get_node()->shared_from_this();
break;
}
}
if (goe == nullptr)
{
goe = get_output_element(o, true);
}
goes.at(o.get_index()) = goe;
goes.at(o.get_index()) = o.as_single_output_node();
}
return goes;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment