Commit fd881acc authored by Scott Cyphers's avatar Scott Cyphers

Review comments.

parent 064fb0fc
...@@ -49,7 +49,7 @@ namespace ngraph ...@@ -49,7 +49,7 @@ namespace ngraph
size_t m_bitwidth; size_t m_bitwidth;
bool m_is_float; bool m_is_float;
bool m_is_signed; bool m_is_signed;
const std::string m_cname; const std::string& m_cname;
}; };
// Provides a compile-time name for a C++ type. // Provides a compile-time name for a C++ type.
...@@ -62,7 +62,7 @@ namespace ngraph ...@@ -62,7 +62,7 @@ namespace ngraph
} }
// Define a type string for a type T. Will make traited_type_name<T>() return "T" // Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TTN(T) \ #define NGRAPH_DEFINE_TRAITED_TYPE_NAME(T) \
template <> \ template <> \
constexpr const char* traited_type_name<T>() \ constexpr const char* traited_type_name<T>() \
{ \ { \
...@@ -95,25 +95,25 @@ namespace ngraph ...@@ -95,25 +95,25 @@ namespace ngraph
} }
}; };
NGRAPH_DEFINE_TTN(float) NGRAPH_DEFINE_TRAITED_TYPE_NAME(float)
using Float = TraitedType<float>; using Float32 = TraitedType<float>;
NGRAPH_DEFINE_TTN(int8_t) NGRAPH_DEFINE_TRAITED_TYPE_NAME(int8_t)
using Int8 = TraitedType<int8_t>; using Int8 = TraitedType<int8_t>;
NGRAPH_DEFINE_TTN(int32_t) NGRAPH_DEFINE_TRAITED_TYPE_NAME(int32_t)
using Int32 = TraitedType<int32_t>; using Int32 = TraitedType<int32_t>;
NGRAPH_DEFINE_TTN(int64_t) NGRAPH_DEFINE_TRAITED_TYPE_NAME(int64_t)
using Int64 = TraitedType<int64_t>; using Int64 = TraitedType<int64_t>;
NGRAPH_DEFINE_TTN(uint8_t) NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint8_t)
using UInt8 = TraitedType<uint8_t>; using UInt8 = TraitedType<uint8_t>;
NGRAPH_DEFINE_TTN(uint32_t) NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint32_t)
using UInt32 = TraitedType<uint32_t>; using UInt32 = TraitedType<uint32_t>;
NGRAPH_DEFINE_TTN(uint64_t) NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint64_t)
using UInt64 = TraitedType<uint64_t>; using UInt64 = TraitedType<uint64_t>;
} }
} }
...@@ -28,9 +28,12 @@ namespace ngraph ...@@ -28,9 +28,12 @@ namespace ngraph
Function(const std::shared_ptr<Node>& result, Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters); const std::vector<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Node> result() { return m_result; } std::shared_ptr<Node> get_result() { return m_result; }
std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; } const std::vector<std::shared_ptr<Parameter>> get_parameters() const
std::string name() const { return m_name; } {
return m_parameters;
}
std::string get_name() const { return m_name; }
protected: protected:
std::shared_ptr<Node> m_result; std::shared_ptr<Node> m_result;
......
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
size_t ngraph::Node::m_next_instance_id = 0; size_t ngraph::Node::m_next_instance_id = 0;
ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
std::shared_ptr<ValueType> type) std::shared_ptr<ValueType> value_type)
: TypedValueMixin(type) : m_arguments(arguments)
, m_arguments(arguments) , m_value_type(value_type)
, m_instance_id(m_next_instance_id++) , m_instance_id(m_next_instance_id++)
{ {
// Add this node as a user of each argument. // Add this node as a user of each argument.
......
...@@ -30,10 +30,19 @@ namespace ngraph ...@@ -30,10 +30,19 @@ namespace ngraph
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor /// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
class Node : public TypedValueMixin, public std::enable_shared_from_this<Node> class Node : public std::enable_shared_from_this<Node>
{ {
protected: protected:
Node(const Nodes& arguments, std::shared_ptr<ValueType> type = nullptr); Node(const Nodes& arguments, std::shared_ptr<ValueType> value_type = nullptr);
Node()
: Node({}, nullptr)
{
}
Node(std::shared_ptr<ValueType> value_type)
: Node({}, value_type)
{
}
virtual ~Node() {} virtual ~Node() {}
...@@ -61,6 +70,19 @@ namespace ngraph ...@@ -61,6 +70,19 @@ namespace ngraph
return typeid(*this) == typeid(*node.get()); return typeid(*this) == typeid(*node.get());
} }
std::shared_ptr<ValueType> get_value_type() { return m_value_type; }
const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; }
void set_value_type(const element::Type& element_type, const Shape& shape)
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
void set_value_type(const std::shared_ptr<ValueType>& value_type)
{
m_value_type = value_type;
}
bool is_op() const; bool is_op() const;
bool is_parameter() const; bool is_parameter() const;
...@@ -69,6 +91,7 @@ namespace ngraph ...@@ -69,6 +91,7 @@ namespace ngraph
protected: protected:
Nodes m_arguments; Nodes m_arguments;
std::shared_ptr<ValueType> m_value_type;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
std::string m_name; std::string m_name;
size_t m_instance_id; size_t m_instance_id;
......
...@@ -80,7 +80,12 @@ namespace ngraph ...@@ -80,7 +80,12 @@ namespace ngraph
{ {
public: public:
Op(const std::vector<std::shared_ptr<Node>>& arguments) Op(const std::vector<std::shared_ptr<Node>>& arguments)
: Node(arguments, nullptr) : Node(arguments)
{
}
Op()
: Node()
{ {
} }
......
...@@ -63,7 +63,7 @@ namespace ngraph ...@@ -63,7 +63,7 @@ namespace ngraph
typename T::type m_value; typename T::type m_value;
}; };
using FloatScalarConstant = ScalarConstant<element::Float>; using Float32ScalarConstant = ScalarConstant<element::Float32>;
using Int8ScalarConstant = ScalarConstant<element::Int8>; using Int8ScalarConstant = ScalarConstant<element::Int8>;
using Int32ScalarConstant = ScalarConstant<element::Int32>; using Int32ScalarConstant = ScalarConstant<element::Int32>;
using Int64ScalarConstant = ScalarConstant<element::Int64>; using Int64ScalarConstant = ScalarConstant<element::Int64>;
......
...@@ -82,46 +82,4 @@ namespace ngraph ...@@ -82,46 +82,4 @@ namespace ngraph
protected: protected:
std::vector<std::shared_ptr<ValueType>> m_element_types; std::vector<std::shared_ptr<ValueType>> m_element_types;
}; };
/**
** Mixin for objects with type information
**/
class TypedValueMixin
{
public:
TypedValueMixin(const std::shared_ptr<ValueType>& value_type = nullptr)
: m_value_type(value_type)
{
}
/**
** Set the type
** /param type The new type
**/
void set_value_type(const std::shared_ptr<ValueType>& value_type)
{
m_value_type = value_type;
}
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void set_value_type(const element::Type& element_type, const Shape& shape)
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
/**
** The type associated with this value.
**/
std::shared_ptr<ValueType> get_value_type() { return m_value_type; }
/**
** The type associated with this value.
**/
const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; }
protected:
std::shared_ptr<ValueType> m_value_type;
};
} }
...@@ -20,7 +20,7 @@ using namespace std; ...@@ -20,7 +20,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
Parameter::Parameter(const std::shared_ptr<ValueType>& value_type) Parameter::Parameter(const std::shared_ptr<ValueType>& value_type)
: Node({}, value_type) : Node(value_type)
, m_function(nullptr) , m_function(nullptr)
, m_index(0) , m_index(0)
{ {
......
...@@ -23,10 +23,10 @@ using namespace ngraph; ...@@ -23,10 +23,10 @@ using namespace ngraph;
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto arg0 = node<Parameter>(element::Float::element_type(), Shape{7, 3}); auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = node<Parameter>(element::Float::element_type(), Shape{3}); auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float::element_type(), Shape{32, 7}); auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = node<Parameter>(element::Float::element_type(), Shape{32, 7}); auto arg3 = node<Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = node<DotOp>(arg2, arg0); auto dot = node<DotOp>(arg2, arg0);
...@@ -35,14 +35,14 @@ TEST(build_graph, build_simple) ...@@ -35,14 +35,14 @@ TEST(build_graph, build_simple)
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3}); auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->result(), dot); ASSERT_EQ(cluster_0->get_result(), dot);
} }
// Check upcasting from ValueType. // Check upcasting from ValueType.
TEST(build_graph, as_type) TEST(build_graph, as_type)
{ {
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple. // Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
auto tv_vt = make_shared<TensorViewType>(element::Float::element_type(), Shape{2, 3, 5}); auto tv_vt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt); auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv); ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt); auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
...@@ -59,14 +59,14 @@ TEST(build_graph, as_type) ...@@ -59,14 +59,14 @@ TEST(build_graph, as_type)
// Check node comparisons // Check node comparisons
TEST(build_graph, node_comparison) TEST(build_graph, node_comparison)
{ {
auto arg0 = node<Parameter>(element::Float::element_type(), Shape{32, 3}); auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{32, 3});
auto arg1 = node<Parameter>(element::Float::element_type(), Shape{3}); auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float::element_type(), Shape{32}); auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32});
auto dot = op::dot(arg0, arg1); auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2); auto add = op::add(dot, arg2);
auto parg = node<Parameter>(element::Float::element_type(), Shape{}); auto parg = node<Parameter>(element::Float32::element_type(), Shape{});
auto pattern_dot = node<DotOp>(parg, parg); auto pattern_dot = node<DotOp>(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot)); ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented. // TODO This passes because typeid is not behaving as documented.
...@@ -78,8 +78,8 @@ TEST(build_graph, literal) ...@@ -78,8 +78,8 @@ TEST(build_graph, literal)
{ {
// float scalar from a float // float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0); //auto float0 = FloatScalarConstant::make(3.0);
auto float0 = node<FloatScalarConstant>(3.0); auto float0 = node<Float32ScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{}); auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
ASSERT_EQ(float0->get_value(), 3.0); ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->get_value_type(), float_scalar_type); ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
auto d = node<DotOp>(float0, float0); auto d = node<DotOp>(float0, float0);
...@@ -87,7 +87,7 @@ TEST(build_graph, literal) ...@@ -87,7 +87,7 @@ TEST(build_graph, literal)
ASSERT_EQ(d->get_arguments().at(1), float0); ASSERT_EQ(d->get_arguments().at(1), float0);
// float scalar from an int // float scalar from an int
auto float1 = node<FloatScalarConstant>(3); auto float1 = node<Float32ScalarConstant>(3);
ASSERT_EQ(float1->get_value(), 3); ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->get_value_type(), float_scalar_type); ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
......
...@@ -23,7 +23,7 @@ using namespace ngraph; ...@@ -23,7 +23,7 @@ using namespace ngraph;
TEST(op, is_op) TEST(op, is_op)
{ {
auto arg0 = op::parameter(element::Float::element_type(), {1}); auto arg0 = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg0); ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter()); EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op()); EXPECT_FALSE(arg0->is_op());
...@@ -31,7 +31,7 @@ TEST(op, is_op) ...@@ -31,7 +31,7 @@ TEST(op, is_op)
TEST(op, is_parameter) TEST(op, is_parameter)
{ {
auto arg0 = op::parameter(element::Float::element_type(), {1}); auto arg0 = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg0); ASSERT_NE(nullptr, arg0);
auto t0 = op::add(arg0, arg0); auto t0 = op::add(arg0, arg0);
ASSERT_NE(nullptr, t0); ASSERT_NE(nullptr, t0);
......
...@@ -61,7 +61,7 @@ TEST(topological_sort, basic) ...@@ -61,7 +61,7 @@ TEST(topological_sort, basic)
vector<shared_ptr<Parameter>> args; vector<shared_ptr<Parameter>> args;
for (int i = 0; i < 10; i++) for (int i = 0; i < 10; i++)
{ {
auto arg = op::parameter(element::Float::element_type(), {1}); auto arg = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg); ASSERT_NE(nullptr, arg);
args.push_back(arg); args.push_back(arg);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment