Commit a90f6bf4 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #159 from NervanaSystems/cyphers/node

Use weak_ptr for node in inputs/outputs, turn off alignment style.
parents c95c8b68 eb22cbb4
......@@ -6,8 +6,8 @@ Standard: Cpp11
AccessModifierOffset: -4
AlignConsecutiveDeclarations: true
AlignConsecutiveAssignments: true
AlignConsecutiveDeclarations: false
AlignConsecutiveAssignments: false
AlignTrailingComments: true
AllowShortBlocksOnASingleLine: true
......
......@@ -19,7 +19,7 @@ using namespace ngraph;
using namespace descriptor;
Input::Input(
Node* node, size_t index, size_t argno, size_t arg_index, Output& output)
const std::shared_ptr<Node>& node, size_t index, size_t argno, size_t arg_index, Output& output)
: m_node(node)
, m_index(index)
, m_argno(argno)
......@@ -31,7 +31,7 @@ Input::Input(
std::shared_ptr<Node> Input::get_node()
{
return m_node->shared_from_this();
return m_node.lock();
}
const Tensor& Input::get_tensor() const
......
......@@ -32,31 +32,31 @@ namespace ngraph
friend class Node;
public:
/// @param node The node that owns this input; not shared to prevent owner loop
/// @param node The node that owns this input
/// @param index The position of this this tensor in all input tensors
/// @param argno The position of the argument with this tensor
/// @param arg_index The position of the tensor within the argument's tensors
/// @param output The output that supplies a value for this input
Input(Node* node,
size_t index,
size_t argno,
size_t arg_index,
Input(const std::shared_ptr<Node>& node,
size_t index,
size_t argno,
size_t arg_index,
Output& output);
std::shared_ptr<Node> get_node();
size_t get_argno() const { return m_argno; }
size_t get_arg_index() const { return m_arg_index; }
size_t get_index() const { return m_index; }
const Output& get_output() const { return m_output; }
Output& get_output() { return m_output; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
size_t get_argno() const { return m_argno; }
size_t get_arg_index() const { return m_arg_index; }
size_t get_index() const { return m_index; }
const Output& get_output() const { return m_output; }
Output& get_output() { return m_output; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
protected:
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
std::weak_ptr<Node> m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
Output& m_output;
private:
......
......@@ -18,7 +18,9 @@ using namespace std;
using namespace ngraph;
using namespace descriptor;
Output::Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view)
Output::Output(const std::shared_ptr<Node>& node,
size_t index,
const std::shared_ptr<TensorView>& tensor_view)
: m_node(node)
, m_index(index)
, m_tensor_view(tensor_view)
......@@ -33,7 +35,7 @@ void Output::add_input(Input* input)
std::shared_ptr<Node> Output::get_node() const
{
return m_node->shared_from_this();
return m_node.lock();
}
const Tensor& Output::get_tensor() const
......
......@@ -32,24 +32,26 @@ namespace ngraph
// Output& operator=(const Output&) = delete;
public:
/// @param node Node that owns this output. Not shared to prevent owner loop.
/// @param node Node that owns this output.
/// @param index Position of the output tensor in all output tensors
/// @param tensor_view The view of this tensor; where the value will be written
Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view);
Output(const std::shared_ptr<Node>& node,
size_t index,
const std::shared_ptr<TensorView>& tensor_view);
std::shared_ptr<Node> get_node() const;
size_t get_index() const { return m_index; }
std::shared_ptr<Node> get_node() const;
size_t get_index() const { return m_index; }
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void add_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
const Tensor& get_tensor() const;
Tensor& get_tensor();
protected:
Node* m_node;
size_t m_index;
std::weak_ptr<Node> m_node;
size_t m_index;
std::shared_ptr<TensorView> m_tensor_view;
std::set<Input*> m_inputs;
std::set<Input*> m_inputs;
};
}
}
......@@ -32,9 +32,7 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType>
}
}
Node::~Node()
{
}
Node::~Node() {}
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
{
......@@ -55,22 +53,27 @@ void Node::assign_tensors()
{
vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
get_value_type()->collect_tensor_views(tensor_view_types);
std::shared_ptr<Node> shared_this = shared_from_this();
size_t i = 0;
for (auto tvt : tensor_view_types)
{
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt, ngraph::descriptor::Tensor::make_tensor_name(this, i), is_output(), is_parameter());
m_outputs.emplace_back(this, i, tensor_view_descriptor);
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(
tvt,
ngraph::descriptor::Tensor::make_tensor_name(this, i),
is_output(),
is_parameter());
m_outputs.emplace_back(shared_this, i, tensor_view_descriptor);
i++;
}
i = 0;
i = 0;
size_t argno = 0;
for (auto arg : get_arguments())
{
size_t arg_index = 0;
for (descriptor::Output& output : arg->get_outputs())
{
m_inputs.emplace_back(this, i, argno, arg_index++, output);
m_inputs.emplace_back(shared_this, i, argno, arg_index++, output);
i++;
}
argno++;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment