Commit cb901c1f authored by Bob Kimball's avatar Bob Kimball

Change Input/Output to instances in node rather than shared_ptr

parent 13c281ae
...@@ -19,14 +19,14 @@ using namespace ngraph; ...@@ -19,14 +19,14 @@ using namespace ngraph;
using namespace descriptor; using namespace descriptor;
Input::Input( Input::Input(
Node* node, size_t index, size_t argno, size_t arg_index, const shared_ptr<Output>& output) Node* node, size_t index, size_t argno, size_t arg_index, Output& output)
: m_node(node) : m_node(node)
, m_index(index) , m_index(index)
, m_argno(argno) , m_argno(argno)
, m_arg_index(arg_index) , m_arg_index(arg_index)
, m_output(output) , m_output(output)
{ {
output->add_input(this); output.add_input(this);
} }
std::shared_ptr<Node> Input::get_node() std::shared_ptr<Node> Input::get_node()
......
...@@ -18,15 +18,16 @@ ...@@ -18,15 +18,16 @@
namespace ngraph namespace ngraph
{ {
class Node;
namespace descriptor namespace descriptor
{ {
class Output; class Output;
// Describes a tensor that is an input to an op, directly or indirectly via a tuple // Describes a tensor that is an input to an op, directly or indirectly via a tuple
class Input : public std::enable_shared_from_this<Input> class Input
{ {
Input(const Input&) = delete; friend class Node;
Input& operator=(const Input&) = delete;
public: public:
/// @param node The node that owns this input; not shared to prevent owner loop /// @param node The node that owns this input; not shared to prevent owner loop
...@@ -34,24 +35,29 @@ namespace ngraph ...@@ -34,24 +35,29 @@ namespace ngraph
/// @param argno The position of the argument with this tensor /// @param argno The position of the argument with this tensor
/// @param arg_index The position of the tensor within the argument's tensors /// @param arg_index The position of the tensor within the argument's tensors
/// @param output The output that supplies a value for this input /// @param output The output that supplies a value for this input
Input(Node* node, Input(Node* node,
size_t index, size_t index,
size_t argno, size_t argno,
size_t arg_index, size_t arg_index,
const std::shared_ptr<Output>& output); Output& output);
std::shared_ptr<Node> get_node(); std::shared_ptr<Node> get_node();
size_t get_argno() const { return m_argno; } size_t get_argno() const { return m_argno; }
size_t get_arg_index() const { return m_arg_index; } size_t get_arg_index() const { return m_arg_index; }
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
std::shared_ptr<Output> get_output() const { return m_output; } const Output& get_output() const { return m_output; }
protected: protected:
Node* m_node; // The node we are an input for Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors size_t m_arg_index; // Index into arg's tensors
std::shared_ptr<Output> m_output; Output& m_output;
private:
// Input(const Input&) = default;
// Input(Input&&) = default;
// Input& operator=(const Input&) = delete;
}; };
} }
} }
...@@ -24,10 +24,10 @@ namespace ngraph ...@@ -24,10 +24,10 @@ namespace ngraph
namespace descriptor namespace descriptor
{ {
// Describes an output tensor of an op // Describes an output tensor of an op
class Output : public std::enable_shared_from_this<Output> class Output
{ {
Output(const Output&) = delete; // Output(const Output&) = delete;
Output& operator=(const Output&) = delete; // Output& operator=(const Output&) = delete;
public: public:
/// @param node Node that owns this output. Not shared to prevent owner loop. /// @param node Node that owns this output. Not shared to prevent owner loop.
......
...@@ -31,6 +31,10 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType> ...@@ -31,6 +31,10 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType>
} }
} }
Node::~Node()
{
}
void Node::set_value_type_checked(const shared_ptr<ValueType>& value_type) void Node::set_value_type_checked(const shared_ptr<ValueType>& value_type)
{ {
if (nullptr == m_value_type) if (nullptr == m_value_type)
...@@ -54,8 +58,7 @@ void Node::assign_tensors() ...@@ -54,8 +58,7 @@ void Node::assign_tensors()
for (auto tvt : tensor_view_types) for (auto tvt : tensor_view_types)
{ {
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt); auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt);
auto output = make_shared<descriptor::Output>(this, i++, tensor_view_descriptor); m_outputs.emplace_back(this, i++, tensor_view_descriptor);
m_outputs.push_back(output);
} }
i = 0; i = 0;
...@@ -63,10 +66,9 @@ void Node::assign_tensors() ...@@ -63,10 +66,9 @@ void Node::assign_tensors()
for (auto arg : get_arguments()) for (auto arg : get_arguments())
{ {
size_t arg_index = 0; size_t arg_index = 0;
for (auto output : arg->get_outputs()) for (descriptor::Output& output : arg->get_outputs())
{ {
auto input = make_shared<descriptor::Input>(this, i++, argno, arg_index++, output); m_inputs.emplace_back(this, i++, argno, arg_index++, output);
m_inputs.push_back(input);
} }
argno++; argno++;
} }
......
...@@ -50,7 +50,7 @@ namespace ngraph ...@@ -50,7 +50,7 @@ namespace ngraph
{ {
} }
virtual ~Node() {} virtual ~Node();
public: public:
/// A "one-liner" describing this node. /// A "one-liner" describing this node.
...@@ -106,17 +106,17 @@ namespace ngraph ...@@ -106,17 +106,17 @@ namespace ngraph
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
const std::vector<std::shared_ptr<descriptor::Input>>& get_inputs() { return m_inputs; } std::vector<descriptor::Input>& get_inputs() { return m_inputs; }
const std::vector<std::shared_ptr<descriptor::Output>>& get_outputs() { return m_outputs; } std::vector<descriptor::Output>& get_outputs() { return m_outputs; }
protected: protected:
Nodes m_arguments; Nodes m_arguments;
std::shared_ptr<ValueType> m_value_type; std::shared_ptr<ValueType> m_value_type;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
std::string m_name; std::string m_name;
size_t m_instance_id; size_t m_instance_id;
static size_t m_next_instance_id; static size_t m_next_instance_id;
std::vector<std::shared_ptr<descriptor::Input>> m_inputs; std::vector<descriptor::Input> m_inputs;
std::vector<std::shared_ptr<descriptor::Output>> m_outputs; std::vector<descriptor::Output> m_outputs;
}; };
} }
...@@ -32,11 +32,11 @@ TEST(input_output, param_tensor) ...@@ -32,11 +32,11 @@ TEST(input_output, param_tensor)
for (size_t i = 0; i < param->get_outputs().size(); i++) for (size_t i = 0; i < param->get_outputs().size(); i++)
{ {
auto output = param->get_outputs()[i]; auto output = param->get_outputs()[i];
ASSERT_EQ(i, output->get_index()); ASSERT_EQ(i, output.get_index());
ASSERT_EQ(param, output->get_node()); ASSERT_EQ(param, output.get_node());
} }
ASSERT_EQ(*tv_tp, *param->get_outputs()[0]->get_tensor_view()->get_tensor_view_type()); ASSERT_EQ(*tv_tp, *param->get_outputs()[0].get_tensor_view()->get_tensor_view_type());
} }
TEST(input_output, param_tuple) TEST(input_output, param_tuple)
...@@ -53,12 +53,12 @@ TEST(input_output, param_tuple) ...@@ -53,12 +53,12 @@ TEST(input_output, param_tuple)
for (size_t i = 0; i < param->get_outputs().size(); i++) for (size_t i = 0; i < param->get_outputs().size(); i++)
{ {
auto output = param->get_outputs()[i]; auto output = param->get_outputs()[i];
ASSERT_EQ(i, output->get_index()); ASSERT_EQ(i, output.get_index());
ASSERT_EQ(param, output->get_node()); ASSERT_EQ(param, output.get_node());
} }
ASSERT_EQ(*tv_tp_0, *param->get_outputs()[0]->get_tensor_view()->get_tensor_view_type()); ASSERT_EQ(*tv_tp_0, *param->get_outputs()[0].get_tensor_view()->get_tensor_view_type());
ASSERT_EQ(*tv_tp_1, *param->get_outputs()[1]->get_tensor_view()->get_tensor_view_type()); ASSERT_EQ(*tv_tp_1, *param->get_outputs()[1].get_tensor_view()->get_tensor_view_type());
} }
TEST(input_output, simple_output) TEST(input_output, simple_output)
...@@ -92,9 +92,9 @@ TEST(input_output, simple_output) ...@@ -92,9 +92,9 @@ TEST(input_output, simple_output)
for (size_t i = 0; i < inputs.size(); i++) for (size_t i = 0; i < inputs.size(); i++)
{ {
auto input = inputs[i]; auto input = inputs[i];
ASSERT_EQ(i, input->get_index()); ASSERT_EQ(i, input.get_index());
ASSERT_EQ(i, input->get_argno()); ASSERT_EQ(i, input.get_argno());
ASSERT_EQ(0, input->get_arg_index()); ASSERT_EQ(0, input.get_arg_index());
ASSERT_EQ(input->get_output()->get_node(), add->get_arguments()[i]); ASSERT_EQ(input.get_output().get_node(), add->get_arguments()[i]);
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment