Commit eb22cbb4 authored by Scott Cyphers's avatar Scott Cyphers

Use weak_ptr for node in inputs/outputs, turn off alignment style.

parent c95c8b68
...@@ -6,8 +6,8 @@ Standard: Cpp11 ...@@ -6,8 +6,8 @@ Standard: Cpp11
AccessModifierOffset: -4 AccessModifierOffset: -4
AlignConsecutiveDeclarations: true AlignConsecutiveDeclarations: false
AlignConsecutiveAssignments: true AlignConsecutiveAssignments: false
AlignTrailingComments: true AlignTrailingComments: true
AllowShortBlocksOnASingleLine: true AllowShortBlocksOnASingleLine: true
......
...@@ -19,7 +19,7 @@ using namespace ngraph; ...@@ -19,7 +19,7 @@ using namespace ngraph;
using namespace descriptor; using namespace descriptor;
Input::Input( Input::Input(
Node* node, size_t index, size_t argno, size_t arg_index, Output& output) const std::shared_ptr<Node>& node, size_t index, size_t argno, size_t arg_index, Output& output)
: m_node(node) : m_node(node)
, m_index(index) , m_index(index)
, m_argno(argno) , m_argno(argno)
...@@ -31,7 +31,7 @@ Input::Input( ...@@ -31,7 +31,7 @@ Input::Input(
std::shared_ptr<Node> Input::get_node() std::shared_ptr<Node> Input::get_node()
{ {
return m_node->shared_from_this(); return m_node.lock();
} }
const Tensor& Input::get_tensor() const const Tensor& Input::get_tensor() const
......
...@@ -32,31 +32,31 @@ namespace ngraph ...@@ -32,31 +32,31 @@ namespace ngraph
friend class Node; friend class Node;
public: public:
/// @param node The node that owns this input; not shared to prevent owner loop /// @param node The node that owns this input
/// @param index The position of this this tensor in all input tensors /// @param index The position of this this tensor in all input tensors
/// @param argno The position of the argument with this tensor /// @param argno The position of the argument with this tensor
/// @param arg_index The position of the tensor within the argument's tensors /// @param arg_index The position of the tensor within the argument's tensors
/// @param output The output that supplies a value for this input /// @param output The output that supplies a value for this input
Input(Node* node, Input(const std::shared_ptr<Node>& node,
size_t index, size_t index,
size_t argno, size_t argno,
size_t arg_index, size_t arg_index,
Output& output); Output& output);
std::shared_ptr<Node> get_node(); std::shared_ptr<Node> get_node();
size_t get_argno() const { return m_argno; } size_t get_argno() const { return m_argno; }
size_t get_arg_index() const { return m_arg_index; } size_t get_arg_index() const { return m_arg_index; }
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
const Output& get_output() const { return m_output; } const Output& get_output() const { return m_output; }
Output& get_output() { return m_output; } Output& get_output() { return m_output; }
const Tensor& get_tensor() const; const Tensor& get_tensor() const;
Tensor& get_tensor(); Tensor& get_tensor();
protected: protected:
Node* m_node; // The node we are an input for std::weak_ptr<Node> m_node; // The node we are an input for
size_t m_index; // Index into all input tensors size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors size_t m_arg_index; // Index into arg's tensors
Output& m_output; Output& m_output;
private: private:
......
...@@ -18,7 +18,9 @@ using namespace std; ...@@ -18,7 +18,9 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace descriptor; using namespace descriptor;
Output::Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view) Output::Output(const std::shared_ptr<Node>& node,
size_t index,
const std::shared_ptr<TensorView>& tensor_view)
: m_node(node) : m_node(node)
, m_index(index) , m_index(index)
, m_tensor_view(tensor_view) , m_tensor_view(tensor_view)
...@@ -33,7 +35,7 @@ void Output::add_input(Input* input) ...@@ -33,7 +35,7 @@ void Output::add_input(Input* input)
std::shared_ptr<Node> Output::get_node() const std::shared_ptr<Node> Output::get_node() const
{ {
return m_node->shared_from_this(); return m_node.lock();
} }
const Tensor& Output::get_tensor() const const Tensor& Output::get_tensor() const
......
...@@ -32,24 +32,26 @@ namespace ngraph ...@@ -32,24 +32,26 @@ namespace ngraph
// Output& operator=(const Output&) = delete; // Output& operator=(const Output&) = delete;
public: public:
/// @param node Node that owns this output. Not shared to prevent owner loop. /// @param node Node that owns this output.
/// @param index Position of the output tensor in all output tensors /// @param index Position of the output tensor in all output tensors
/// @param tensor_view The view of this tensor; where the value will be written /// @param tensor_view The view of this tensor; where the value will be written
Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view); Output(const std::shared_ptr<Node>& node,
size_t index,
const std::shared_ptr<TensorView>& tensor_view);
std::shared_ptr<Node> get_node() const; std::shared_ptr<Node> get_node() const;
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; } std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void add_input(Input* input); void add_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; } const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const; const Tensor& get_tensor() const;
Tensor& get_tensor(); Tensor& get_tensor();
protected: protected:
Node* m_node; std::weak_ptr<Node> m_node;
size_t m_index; size_t m_index;
std::shared_ptr<TensorView> m_tensor_view; std::shared_ptr<TensorView> m_tensor_view;
std::set<Input*> m_inputs; std::set<Input*> m_inputs;
}; };
} }
} }
...@@ -32,9 +32,7 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType> ...@@ -32,9 +32,7 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType>
} }
} }
Node::~Node() Node::~Node() {}
{
}
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type) void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
{ {
...@@ -55,22 +53,27 @@ void Node::assign_tensors() ...@@ -55,22 +53,27 @@ void Node::assign_tensors()
{ {
vector<std::shared_ptr<const TensorViewType>> tensor_view_types; vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
get_value_type()->collect_tensor_views(tensor_view_types); get_value_type()->collect_tensor_views(tensor_view_types);
std::shared_ptr<Node> shared_this = shared_from_this();
size_t i = 0; size_t i = 0;
for (auto tvt : tensor_view_types) for (auto tvt : tensor_view_types)
{ {
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt, ngraph::descriptor::Tensor::make_tensor_name(this, i), is_output(), is_parameter()); auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(
m_outputs.emplace_back(this, i, tensor_view_descriptor); tvt,
ngraph::descriptor::Tensor::make_tensor_name(this, i),
is_output(),
is_parameter());
m_outputs.emplace_back(shared_this, i, tensor_view_descriptor);
i++; i++;
} }
i = 0; i = 0;
size_t argno = 0; size_t argno = 0;
for (auto arg : get_arguments()) for (auto arg : get_arguments())
{ {
size_t arg_index = 0; size_t arg_index = 0;
for (descriptor::Output& output : arg->get_outputs()) for (descriptor::Output& output : arg->get_outputs())
{ {
m_inputs.emplace_back(this, i, argno, arg_index++, output); m_inputs.emplace_back(shared_this, i, argno, arg_index++, output);
i++; i++;
} }
argno++; argno++;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment