Commit cb901c1f authored by Bob Kimball's avatar Bob Kimball

Change Input/Output to instances in node rather than shared_ptr

parent 13c281ae
......@@ -19,14 +19,14 @@ using namespace ngraph;
using namespace descriptor;
Input::Input(
Node* node, size_t index, size_t argno, size_t arg_index, const shared_ptr<Output>& output)
Node* node, size_t index, size_t argno, size_t arg_index, Output& output)
: m_node(node)
, m_index(index)
, m_argno(argno)
, m_arg_index(arg_index)
, m_output(output)
{
output->add_input(this);
output.add_input(this);
}
std::shared_ptr<Node> Input::get_node()
......
......@@ -18,15 +18,16 @@
namespace ngraph
{
class Node;
namespace descriptor
{
class Output;
// Describes a tensor that is an input to an op, directly or indirectly via a tuple
class Input : public std::enable_shared_from_this<Input>
class Input
{
Input(const Input&) = delete;
Input& operator=(const Input&) = delete;
friend class Node;
public:
/// @param node The node that owns this input; not shared to prevent owner loop
......@@ -38,20 +39,25 @@ namespace ngraph
size_t index,
size_t argno,
size_t arg_index,
const std::shared_ptr<Output>& output);
Output& output);
std::shared_ptr<Node> get_node();
size_t get_argno() const { return m_argno; }
size_t get_arg_index() const { return m_arg_index; }
size_t get_index() const { return m_index; }
std::shared_ptr<Output> get_output() const { return m_output; }
const Output& get_output() const { return m_output; }
protected:
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
std::shared_ptr<Output> m_output;
Output& m_output;
private:
// Input(const Input&) = default;
// Input(Input&&) = default;
// Input& operator=(const Input&) = delete;
};
}
}
......@@ -24,10 +24,10 @@ namespace ngraph
namespace descriptor
{
// Describes an output tensor of an op
class Output : public std::enable_shared_from_this<Output>
class Output
{
Output(const Output&) = delete;
Output& operator=(const Output&) = delete;
// Output(const Output&) = delete;
// Output& operator=(const Output&) = delete;
public:
/// @param node Node that owns this output. Not shared to prevent owner loop.
......
......@@ -31,6 +31,10 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType>
}
}
Node::~Node()
{
}
void Node::set_value_type_checked(const shared_ptr<ValueType>& value_type)
{
if (nullptr == m_value_type)
......@@ -54,8 +58,7 @@ void Node::assign_tensors()
for (auto tvt : tensor_view_types)
{
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt);
auto output = make_shared<descriptor::Output>(this, i++, tensor_view_descriptor);
m_outputs.push_back(output);
m_outputs.emplace_back(this, i++, tensor_view_descriptor);
}
i = 0;
......@@ -63,10 +66,9 @@ void Node::assign_tensors()
for (auto arg : get_arguments())
{
size_t arg_index = 0;
for (auto output : arg->get_outputs())
for (descriptor::Output& output : arg->get_outputs())
{
auto input = make_shared<descriptor::Input>(this, i++, argno, arg_index++, output);
m_inputs.push_back(input);
m_inputs.emplace_back(this, i++, argno, arg_index++, output);
}
argno++;
}
......
......@@ -50,7 +50,7 @@ namespace ngraph
{
}
virtual ~Node() {}
virtual ~Node();
public:
/// A "one-liner" describing this node.
......@@ -106,8 +106,8 @@ namespace ngraph
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
const std::vector<std::shared_ptr<descriptor::Input>>& get_inputs() { return m_inputs; }
const std::vector<std::shared_ptr<descriptor::Output>>& get_outputs() { return m_outputs; }
std::vector<descriptor::Input>& get_inputs() { return m_inputs; }
std::vector<descriptor::Output>& get_outputs() { return m_outputs; }
protected:
Nodes m_arguments;
......@@ -116,7 +116,7 @@ namespace ngraph
std::string m_name;
size_t m_instance_id;
static size_t m_next_instance_id;
std::vector<std::shared_ptr<descriptor::Input>> m_inputs;
std::vector<std::shared_ptr<descriptor::Output>> m_outputs;
std::vector<descriptor::Input> m_inputs;
std::vector<descriptor::Output> m_outputs;
};
}
......@@ -32,11 +32,11 @@ TEST(input_output, param_tensor)
for (size_t i = 0; i < param->get_outputs().size(); i++)
{
auto output = param->get_outputs()[i];
ASSERT_EQ(i, output->get_index());
ASSERT_EQ(param, output->get_node());
ASSERT_EQ(i, output.get_index());
ASSERT_EQ(param, output.get_node());
}
ASSERT_EQ(*tv_tp, *param->get_outputs()[0]->get_tensor_view()->get_tensor_view_type());
ASSERT_EQ(*tv_tp, *param->get_outputs()[0].get_tensor_view()->get_tensor_view_type());
}
TEST(input_output, param_tuple)
......@@ -53,12 +53,12 @@ TEST(input_output, param_tuple)
for (size_t i = 0; i < param->get_outputs().size(); i++)
{
auto output = param->get_outputs()[i];
ASSERT_EQ(i, output->get_index());
ASSERT_EQ(param, output->get_node());
ASSERT_EQ(i, output.get_index());
ASSERT_EQ(param, output.get_node());
}
ASSERT_EQ(*tv_tp_0, *param->get_outputs()[0]->get_tensor_view()->get_tensor_view_type());
ASSERT_EQ(*tv_tp_1, *param->get_outputs()[1]->get_tensor_view()->get_tensor_view_type());
ASSERT_EQ(*tv_tp_0, *param->get_outputs()[0].get_tensor_view()->get_tensor_view_type());
ASSERT_EQ(*tv_tp_1, *param->get_outputs()[1].get_tensor_view()->get_tensor_view_type());
}
TEST(input_output, simple_output)
......@@ -92,9 +92,9 @@ TEST(input_output, simple_output)
for (size_t i = 0; i < inputs.size(); i++)
{
auto input = inputs[i];
ASSERT_EQ(i, input->get_index());
ASSERT_EQ(i, input->get_argno());
ASSERT_EQ(0, input->get_arg_index());
ASSERT_EQ(input->get_output()->get_node(), add->get_arguments()[i]);
ASSERT_EQ(i, input.get_index());
ASSERT_EQ(i, input.get_argno());
ASSERT_EQ(0, input.get_arg_index());
ASSERT_EQ(input.get_output().get_node(), add->get_arguments()[i]);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment