Commit f1608316 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #98 from NervanaSystems/cyphers/tensors

Add input/output tensor information to nodes.
parents e92e5e5b 320276bf
...@@ -15,6 +15,10 @@ set (SRC ...@@ -15,6 +15,10 @@ set (SRC
tree.cpp tree.cpp
util.cpp util.cpp
log.cpp log.cpp
ngraph/descriptor/input.cpp
ngraph/descriptor/output.cpp
ngraph/descriptor/tensor.cpp
ngraph/descriptor/tensor_view.cpp
ops/binary_elementwise_builtin.cpp ops/binary_elementwise_builtin.cpp
ops/broadcast.cpp ops/broadcast.cpp
ops/concatenate.cpp ops/concatenate.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace descriptor;
Input::Input(
Node* node, size_t index, size_t argno, size_t arg_index, const shared_ptr<Output>& output)
: m_node(node)
, m_index(index)
, m_argno(argno)
, m_arg_index(arg_index)
, m_output(output)
{
output->add_input(this);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
namespace ngraph
{
namespace descriptor
{
class Output;
// Describes a tensor that is an input to an op, directly or indirectly via a tuple
class Input : public std::enable_shared_from_this<Input>
{
Input(const Input&) = delete;
Input& operator=(const Input&) = delete;
public:
/// @param node The node that owns this input; not shared to prevent owner loop
/// @param index The position of this this tensor in all input tensors
/// @param argno The position of the argument with this tensor
/// @param arg_index The position of the tensor within the argument's tensors
/// @param output The output that supplies a value for this input
Input(Node* node,
size_t index,
size_t argno,
size_t arg_index,
const std::shared_ptr<Output>& output);
std::shared_ptr<Node> get_node() { return m_node->shared_from_this(); }
size_t get_argno() const { return m_argno; }
size_t get_arg_index() const { return m_arg_index; }
size_t get_index() const { return m_index; }
std::shared_ptr<Output> get_output() const { return m_output; }
protected:
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
std::shared_ptr<Output> m_output;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace descriptor;
Output::Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view)
: m_node(node)
, m_index(index)
, m_tensor_view(tensor_view)
{
}
// Add an input to the vector of inputs that use this output.
void Output::add_input(Input* input)
{
m_inputs.insert(input);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <set>
#include "descriptor/tensor_view.hpp"
namespace ngraph
{
namespace descriptor
{
// Describes an output tensor of an op
class Output : public std::enable_shared_from_this<Output>
{
Output(const Output&) = delete;
Output& operator=(const Output&) = delete;
public:
/// @param node Node that owns this output. Not shared to prevent owner loop.
/// @param index Position of the output tensor in all output tensors
/// @param tensor_view The view of this tensor; where the value will be written
Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view);
std::shared_ptr<Node> get_node() const { return m_node->shared_from_this(); }
size_t get_index() const { return m_index; }
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void add_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; }
protected:
Node* m_node;
size_t m_index;
std::shared_ptr<TensorView> m_tensor_view;
std::set<Input*> m_inputs;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "descriptor/tensor.hpp"
using namespace ngraph;
using namespace descriptor;
Tensor::Tensor(const element::Type& element_type, PrimaryTensorView* primary_tensor_view)
: m_element_type(element_type)
, m_primary_tensor_view(primary_tensor_view)
{
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
namespace ngraph
{
namespace element
{
class Type;
}
namespace descriptor
{
class TensorView;
class PrimaryTensorView;
class Tensor
{
friend class PrimaryTensorView;
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
Tensor(const element::Type& element_type, PrimaryTensorView* tensor_view);
protected:
const element::Type& m_element_type;
PrimaryTensorView* m_primary_tensor_view;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "descriptor/tensor_view.hpp"
using namespace ngraph;
using namespace descriptor;
const Tensor& PrimaryTensorView::get_tensor() const
{
return m_tensor;
}
Tensor& PrimaryTensorView::get_tensor()
{
return m_tensor;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "descriptor/tensor.hpp"
#include "shape.hpp"
#include "type.hpp"
namespace ngraph
{
namespace descriptor
{
class Tensor;
class TensorViewLayout;
// Describes a view of an instantiated tensor
class TensorView : public std::enable_shared_from_this<TensorView>
{
TensorView(const TensorView&) = delete;
TensorView& operator=(const TensorView&) = delete;
protected:
TensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type)
: m_tensor_view_type(tensor_view_type)
{
}
public:
virtual ~TensorView() {}
virtual const Tensor& get_tensor() const = 0;
virtual Tensor& get_tensor() = 0;
std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{
return m_tensor_view_type;
}
const std::shared_ptr<TensorViewLayout>& get_tensor_view_layout() const
{
return m_tensor_view_layout;
}
void set_tensor_view_layout(const std::shared_ptr<TensorViewLayout>& tensor_view_layout)
{
m_tensor_view_layout = tensor_view_layout;
}
protected:
std::shared_ptr<const TensorViewType> m_tensor_view_type;
std::shared_ptr<TensorViewLayout> m_tensor_view_layout;
};
// A PrimaryTensorView owns the tensor. All other views are the result
// of some index operation on the primary view.
class PrimaryTensorView : public TensorView
{
public:
PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type)
: TensorView(tensor_view_type)
, m_tensor(tensor_view_type->get_element_type(), this)
{
}
virtual const Tensor& get_tensor() const override;
virtual Tensor& get_tensor() override;
protected:
Tensor m_tensor;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
namespace ngraph
{
namespace descriptor
{
using Strides = std::vector<size_t>;
class TensorViewLayout
{
protected:
Strides m_strides;
};
}
}
...@@ -23,6 +23,11 @@ ...@@ -23,6 +23,11 @@
#include "except.hpp" #include "except.hpp"
#include "function.hpp" #include "function.hpp"
#include "node.hpp" #include "node.hpp"
#include "descriptor/input.hpp"
#include "descriptor/output.hpp"
#include "descriptor/tensor_view.hpp"
#include "descriptor/tensor_view_layout.hpp"
#include "descriptor/tensor.hpp"
#include "op.hpp" #include "op.hpp"
#include "ops/add.hpp" #include "ops/add.hpp"
#include "ops/broadcast.hpp" #include "ops/broadcast.hpp"
......
...@@ -12,13 +12,14 @@ ...@@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "node.hpp" #include "ngraph.hpp"
#include "op.hpp"
size_t ngraph::Node::m_next_instance_id = 0; using namespace std;
using namespace ngraph;
ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, size_t Node::m_next_instance_id = 0;
std::shared_ptr<ValueType> value_type)
Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType> value_type)
: m_arguments(arguments) : m_arguments(arguments)
, m_value_type(value_type) , m_value_type(value_type)
, m_instance_id(m_next_instance_id++) , m_instance_id(m_next_instance_id++)
...@@ -30,33 +31,63 @@ ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, ...@@ -30,33 +31,63 @@ ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
} }
} }
void ngraph::Node::set_value_type_checked(const std::shared_ptr<ValueType>& value_type) void Node::set_value_type_checked(const shared_ptr<ValueType>& value_type)
{ {
if (nullptr == m_value_type){ if (nullptr == m_value_type)
{
m_value_type = value_type; m_value_type = value_type;
} else { }
if (*m_value_type != *value_type){ else
throw ngraph::ngraph_error("Setting value type to a different ValueType"); {
if (*m_value_type != *value_type)
{
throw ngraph_error("Setting value type to a different ValueType");
}
}
}
void Node::assign_tensors()
{
vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
get_value_type()->collect_tensor_views(tensor_view_types);
size_t i = 0;
for (auto tvt : tensor_view_types)
{
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt);
auto output = make_shared<descriptor::Output>(this, i++, tensor_view_descriptor);
m_outputs.push_back(output);
}
i = 0;
size_t argno = 0;
for (auto arg : get_arguments())
{
size_t arg_index = 0;
for (auto output : arg->get_outputs())
{
auto input = make_shared<descriptor::Input>(this, i++, argno, arg_index++, output);
m_inputs.push_back(input);
} }
argno++;
} }
} }
bool ngraph::Node::is_op() const bool Node::is_op() const
{ {
return dynamic_cast<const ngraph::Op*>(this) != nullptr; return dynamic_cast<const Op*>(this) != nullptr;
} }
bool ngraph::Node::is_parameter() const bool Node::is_parameter() const
{ {
return dynamic_cast<const ngraph::op::Parameter*>(this) != nullptr; return dynamic_cast<const op::Parameter*>(this) != nullptr;
} }
namespace ngraph namespace ngraph
{ {
std::ostream& operator<<(std::ostream& out, const ngraph::Node& node) ostream& operator<<(ostream& out, const Node& node)
{ {
auto op_tmp = dynamic_cast<const ngraph::Op*>(&node); auto op_tmp = dynamic_cast<const Op*>(&node);
auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node); auto parameter_tmp = dynamic_cast<const Op*>(&node);
if (op_tmp) if (op_tmp)
{ {
out << "Op(" << op_tmp->get_node_id() << ")"; out << "Op(" << op_tmp->get_node_id() << ")";
......
...@@ -27,6 +27,12 @@ namespace ngraph ...@@ -27,6 +27,12 @@ namespace ngraph
{ {
class Op; class Op;
namespace descriptor
{
class Input;
class Output;
}
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor /// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
...@@ -53,6 +59,10 @@ namespace ngraph ...@@ -53,6 +59,10 @@ namespace ngraph
/// Propagate types and check arguments for consistency /// Propagate types and check arguments for consistency
virtual void propagate_types() = 0; virtual void propagate_types() = 0;
/// Assign Input and Output vectors
// This might later need to be virtual.
void assign_tensors();
const Nodes& get_arguments() const { return m_arguments; } const Nodes& get_arguments() const { return m_arguments; }
const std::multiset<Node*>& users() const { return m_users; } const std::multiset<Node*>& users() const { return m_users; }
...@@ -94,7 +104,9 @@ namespace ngraph ...@@ -94,7 +104,9 @@ namespace ngraph
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
std::vector<std::shared_ptr<descriptor::Input>> get_inputs() { return m_inputs; }
std::vector<std::shared_ptr<descriptor::Output>> get_outputs() {return m_outputs; }
protected: protected:
Nodes m_arguments; Nodes m_arguments;
...@@ -103,5 +115,7 @@ namespace ngraph ...@@ -103,5 +115,7 @@ namespace ngraph
std::string m_name; std::string m_name;
size_t m_instance_id; size_t m_instance_id;
static size_t m_next_instance_id; static size_t m_next_instance_id;
std::vector<std::shared_ptr<descriptor::Input>> m_inputs;
std::vector<std::shared_ptr<descriptor::Output>> m_outputs;
}; };
} }
...@@ -34,10 +34,13 @@ namespace ngraph ...@@ -34,10 +34,13 @@ namespace ngraph
virtual ~ValueType() {} virtual ~ValueType() {}
virtual bool operator==(const ValueType& that) const = 0; virtual bool operator==(const ValueType& that) const = 0;
bool operator!=(const ValueType& that) const { return !(*this == that); } bool operator!=(const ValueType& that) const { return !(*this == that); }
/// Add tensor views in depth-first order.
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0;
}; };
/// Describes a tensor view; an element type and a shape. /// Describes a tensor view; an element type and a shape.
class TensorViewType : public ValueType class TensorViewType : public ValueType, public std::enable_shared_from_this<TensorViewType>
{ {
public: public:
/// /param element_type The type of the tensor elements. /// /param element_type The type of the tensor elements.
...@@ -52,6 +55,7 @@ namespace ngraph ...@@ -52,6 +55,7 @@ namespace ngraph
const Shape& get_shape() const { return m_shape; } const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType& that) const override; virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
...@@ -78,6 +82,7 @@ namespace ngraph ...@@ -78,6 +82,7 @@ namespace ngraph
std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; } std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; }
virtual bool operator==(const ValueType& that) const override; virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
protected: protected:
std::vector<std::shared_ptr<ValueType>> m_element_types; std::vector<std::shared_ptr<ValueType>> m_element_types;
......
...@@ -37,6 +37,11 @@ bool TensorViewType::operator==(const ValueType& that) const ...@@ -37,6 +37,11 @@ bool TensorViewType::operator==(const ValueType& that) const
return true; return true;
} }
void TensorViewType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const
{
views.push_back(shared_from_this());
}
bool TupleType::operator==(const ValueType& that) const bool TupleType::operator==(const ValueType& that) const
{ {
auto that_tvt = dynamic_cast<const TupleType*>(&that); auto that_tvt = dynamic_cast<const TupleType*>(&that);
...@@ -46,3 +51,10 @@ bool TupleType::operator==(const ValueType& that) const ...@@ -46,3 +51,10 @@ bool TupleType::operator==(const ValueType& that) const
} }
return that_tvt->get_element_types() == get_element_types(); return that_tvt->get_element_types() == get_element_types();
} }
void TupleType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const
{
for(auto elt : m_element_types){
elt->collect_tensor_views(views);
}
}
...@@ -24,14 +24,15 @@ include_directories( ...@@ -24,14 +24,15 @@ include_directories(
set (SRC set (SRC
main.cpp main.cpp
build_graph.cpp build_graph.cpp
util.cpp eigen.cpp
tensor.cpp
element_type.cpp element_type.cpp
uuid.cpp op.cpp
input_output_assign.cpp
tensor.cpp
topological_sort.cpp topological_sort.cpp
type_prop.cpp type_prop.cpp
op.cpp util.cpp
eigen.cpp uuid.cpp
) )
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
TEST(input_output, param_tensor)
{
// Params have no arguments, so we can check that the value becomes a tensor output
auto tv_tp = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto param = make_shared<op::Parameter>(tv_tp);
param->propagate_types();
param->assign_tensors();
ASSERT_EQ(param->get_outputs().size(), 1);
for (size_t i = 0; i < param->get_outputs().size(); i++)
{
auto output = param->get_outputs()[i];
ASSERT_EQ(i, output->get_index());
ASSERT_EQ(param, output->get_node());
}
ASSERT_EQ(*tv_tp, *param->get_outputs()[0]->get_tensor_view()->get_tensor_view_type());
}
TEST(input_output, param_tuple)
{
// Same as param_tensor, but for a tuple
auto tv_tp_0 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto tv_tp_1 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4, 6});
auto tp_tp = make_shared<TupleType>(std::vector<std::shared_ptr<ValueType>>{tv_tp_0, tv_tp_1});
auto param = make_shared<op::Parameter>(tp_tp);
param->propagate_types();
param->assign_tensors();
ASSERT_EQ(param->get_outputs().size(), 2);
for (size_t i = 0; i < param->get_outputs().size(); i++)
{
auto output = param->get_outputs()[i];
ASSERT_EQ(i, output->get_index());
ASSERT_EQ(param, output->get_node());
}
ASSERT_EQ(*tv_tp_0, *param->get_outputs()[0]->get_tensor_view()->get_tensor_view_type());
ASSERT_EQ(*tv_tp_1, *param->get_outputs()[1]->get_tensor_view()->get_tensor_view_type());
}
TEST(input_output, simple_output)
{
auto tv_tp_0 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto param_0 = make_shared<op::Parameter>(tv_tp_0);
auto param_1 = make_shared<op::Parameter>(tv_tp_0);
auto add = make_shared<op::Add>(param_0, param_1);
// Sort the ops
vector<shared_ptr<Node>> nodes;
nodes.push_back(param_0);
nodes.push_back(param_1);
nodes.push_back(add);
// Type info
for (auto node : nodes)
{
node->propagate_types();
}
// Add inputs/outputs
for (auto node : nodes)
{
node->assign_tensors();
}
// At this point, the add should have each input associated with the output of the appropriate parameter
auto inputs = add->get_inputs();
ASSERT_EQ(2, inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
auto input = inputs[i];
ASSERT_EQ(i, input->get_index());
ASSERT_EQ(i, input->get_argno());
ASSERT_EQ(0, input->get_arg_index());
ASSERT_EQ(input->get_output()->get_node(), add->get_arguments()[i]);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment