Unverified Commit 8980e2ea authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Switch Output to use a shared_ptr to prevent nodes from disappearing early. (#3109)

parent 52b36eff
...@@ -487,7 +487,7 @@ namespace ngraph ...@@ -487,7 +487,7 @@ namespace ngraph
/// \param node A pointer to the node for the output handle. /// \param node A pointer to the node for the output handle.
/// \param index The index of the output. /// \param index The index of the output.
Output(NodeType* node, size_t index) Output(NodeType* node, size_t index)
: m_node(node) : m_node(node->shared_from_this())
, m_index(index) , m_index(index)
{ {
} }
...@@ -498,7 +498,7 @@ namespace ngraph ...@@ -498,7 +498,7 @@ namespace ngraph
/// ///
/// TODO: Make a plan to deprecate this. /// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<NodeType>& node, size_t index) Output(const std::shared_ptr<NodeType>& node, size_t index)
: m_node(node.get()) : m_node(node)
, m_index(index) , m_index(index)
{ {
} }
...@@ -511,12 +511,15 @@ namespace ngraph ...@@ -511,12 +511,15 @@ namespace ngraph
{ {
} }
// A null output
Output() = default;
/// \return A pointer to the node referred to by this output handle. /// \return A pointer to the node referred to by this output handle.
NodeType* get_node() const { return m_node; } NodeType* get_node() const { return m_node.get(); }
/// \return A `shared_ptr` to the node referred to by this output handle. /// \return A `shared_ptr` to the node referred to by this output handle.
/// ///
/// TODO: Make a plan to deprecate this. /// TODO: Make a plan to deprecate this.
std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node->shared_from_this(); } std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; }
/// \return The index of the output referred to by this output handle. /// \return The index of the output referred to by this output handle.
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
/// \return A reference to the tensor descriptor for this output. /// \return A reference to the tensor descriptor for this output.
...@@ -568,8 +571,8 @@ namespace ngraph ...@@ -568,8 +571,8 @@ namespace ngraph
bool operator<=(const Output& other) const { return !(*this > other); } bool operator<=(const Output& other) const { return !(*this > other); }
bool operator>=(const Output& other) const { return !(*this < other); } bool operator>=(const Output& other) const { return !(*this < other); }
private: private:
NodeType* const m_node; std::shared_ptr<NodeType> m_node;
const size_t m_index; size_t m_index{0};
}; };
inline Input<Node> Node::input(size_t input_index) inline Input<Node> Node::input(size_t input_index)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment