Commit fd881acc authored by Scott Cyphers

Review comments.

parent 064fb0fc
@@ -49,7 +49,7 @@ namespace ngraph
size_t m_bitwidth;
bool m_is_float;
bool m_is_signed;
const std::string m_cname;
const std::string& m_cname;
};
// Provides a compile-time name for a C++ type.
@@ -62,7 +62,7 @@ namespace ngraph
}
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TTN(T) \
#define NGRAPH_DEFINE_TRAITED_TYPE_NAME(T) \
template <> \
constexpr const char* traited_type_name<T>() \
{ \
@@ -95,25 +95,25 @@ namespace ngraph
}
};
NGRAPH_DEFINE_TTN(float)
using Float = TraitedType<float>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(float)
using Float32 = TraitedType<float>;
NGRAPH_DEFINE_TTN(int8_t)
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int8_t)
using Int8 = TraitedType<int8_t>;
NGRAPH_DEFINE_TTN(int32_t)
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int32_t)
using Int32 = TraitedType<int32_t>;
NGRAPH_DEFINE_TTN(int64_t)
NGRAPH_DEFINE_TRAITED_TYPE_NAME(int64_t)
using Int64 = TraitedType<int64_t>;
NGRAPH_DEFINE_TTN(uint8_t)
NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint8_t)
using UInt8 = TraitedType<uint8_t>;
NGRAPH_DEFINE_TTN(uint32_t)
NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint32_t)
using UInt32 = TraitedType<uint32_t>;
NGRAPH_DEFINE_TTN(uint64_t)
NGRAPH_DEFINE_TRAITED_TYPE_NAME(uint64_t)
using UInt64 = TraitedType<uint64_t>;
}
}
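
As a reading aid, here is a minimal, self-contained sketch of the mechanism whose macro is renamed above. It assumes nothing from the ngraph headers (the primary template body below is a placeholder): each NGRAPH_DEFINE_TRAITED_TYPE_NAME(T) invocation specializes traited_type_name<T>() so that it returns the spelled-out name of T, and the real headers then wrap those types in TraitedType<T> behind the Float32/Int8/Int32/... aliases.

    // Minimal sketch; only traited_type_name and the macro mirror the diff,
    // everything else here is scaffolding for a runnable example.
    #include <cstdint>
    #include <iostream>

    template <typename T>
    constexpr const char* traited_type_name()
    {
        return "unknown type"; // placeholder primary template
    }

    // Each use specializes traited_type_name<T>() to return the textual
    // spelling of T via the preprocessor stringizer.
    #define NGRAPH_DEFINE_TRAITED_TYPE_NAME(T)                                 \
        template <>                                                            \
        constexpr const char* traited_type_name<T>()                           \
        {                                                                      \
            return #T;                                                         \
        }

    NGRAPH_DEFINE_TRAITED_TYPE_NAME(float)
    NGRAPH_DEFINE_TRAITED_TYPE_NAME(int32_t)

    int main()
    {
        std::cout << traited_type_name<float>() << "\n";   // prints "float"
        std::cout << traited_type_name<int32_t>() << "\n"; // prints "int32_t"
        return 0;
    }

The rename itself only trades the abbreviation TTN for the explicit TRAITED_TYPE_NAME, and the element alias becomes width-explicit (Float becomes Float32), which is why the tests further down switch to element::Float32::element_type().
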
@@ -28,9 +28,12 @@ namespace ngraph
Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Node> result() { return m_result; }
std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; }
std::string name() const { return m_name; }
std::shared_ptr<Node> get_result() { return m_result; }
const std::vector<std::shared_ptr<Parameter>> get_parameters() const
{
return m_parameters;
}
std::string get_name() const { return m_name; }
protected:
std::shared_ptr<Node> m_result;
......
@@ -18,9 +18,9 @@
size_t ngraph::Node::m_next_instance_id = 0;
ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
std::shared_ptr<ValueType> type)
: TypedValueMixin(type)
, m_arguments(arguments)
std::shared_ptr<ValueType> value_type)
: m_arguments(arguments)
, m_value_type(value_type)
, m_instance_id(m_next_instance_id++)
{
// Add this node as a user of each argument.
......
@@ -30,10 +30,19 @@ namespace ngraph
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values.
class Node : public TypedValueMixin, public std::enable_shared_from_this<Node>
class Node : public std::enable_shared_from_this<Node>
{
protected:
Node(const Nodes& arguments, std::shared_ptr<ValueType> type = nullptr);
Node(const Nodes& arguments, std::shared_ptr<ValueType> value_type = nullptr);
Node()
: Node({}, nullptr)
{
}
Node(std::shared_ptr<ValueType> value_type)
: Node({}, value_type)
{
}
virtual ~Node() {}
@@ -61,6 +70,19 @@ namespace ngraph
return typeid(*this) == typeid(*node.get());
}
std::shared_ptr<ValueType> get_value_type() { return m_value_type; }
const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; }
void set_value_type(const element::Type& element_type, const Shape& shape)
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
void set_value_type(const std::shared_ptr<ValueType>& value_type)
{
m_value_type = value_type;
}
bool is_op() const;
bool is_parameter() const;
@@ -68,10 +90,11 @@ namespace ngraph
friend std::ostream& operator<<(std::ostream&, const Node&);
protected:
Nodes m_arguments;
std::multiset<Node*> m_users;
std::string m_name;
size_t m_instance_id;
static size_t m_next_instance_id;
Nodes m_arguments;
std::shared_ptr<ValueType> m_value_type;
std::multiset<Node*> m_users;
std::string m_name;
size_t m_instance_id;
static size_t m_next_instance_id;
};
}
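
To make the Node changes concrete, the following compact sketch shows the pattern the hunks above adopt. ValueType is a stub here, and the users/name/instance-id bookkeeping of the real class is omitted: the value type is stored on Node itself rather than coming from a TypedValueMixin base, the new Node() and Node(value_type) constructors delegate to the main constructor, and get_value_type/set_value_type are plain members of Node.

    // Stand-alone sketch; only the constructor/value-type pattern mirrors the
    // diff, the rest is scaffolding for a runnable example.
    #include <cassert>
    #include <memory>
    #include <vector>

    struct ValueType // stub for ngraph::ValueType
    {
        virtual ~ValueType() = default;
    };

    class Node
    {
    public:
        // Main constructor: arguments plus an optional value type.
        Node(const std::vector<std::shared_ptr<Node>>& arguments,
             std::shared_ptr<ValueType> value_type = nullptr)
            : m_arguments(arguments)
            , m_value_type(value_type)
        {
        }

        // Delegating constructors corresponding to the ones added above.
        Node()
            : Node({}, nullptr)
        {
        }

        Node(std::shared_ptr<ValueType> value_type)
            : Node({}, value_type)
        {
        }

        std::shared_ptr<ValueType> get_value_type() { return m_value_type; }
        void set_value_type(const std::shared_ptr<ValueType>& value_type)
        {
            m_value_type = value_type;
        }

    protected:
        std::vector<std::shared_ptr<Node>> m_arguments;
        std::shared_ptr<ValueType> m_value_type;
    };

    int main()
    {
        Node a;                                          // no arguments, no type yet
        assert(a.get_value_type() == nullptr);
        a.set_value_type(std::make_shared<ValueType>()); // type filled in later
        assert(a.get_value_type() != nullptr);

        Node b(std::make_shared<ValueType>());           // typed from the start
        assert(b.get_value_type() != nullptr);
        return 0;
    }

The Parameter constructor change further below, Node(value_type) instead of Node({}, value_type), relies on exactly that second delegating constructor.
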
@@ -80,7 +80,12 @@ namespace ngraph
{
public:
Op(const std::vector<std::shared_ptr<Node>>& arguments)
: Node(arguments, nullptr)
: Node(arguments)
{
}
Op()
: Node()
{
}
......
@@ -63,7 +63,7 @@ namespace ngraph
typename T::type m_value;
};
using FloatScalarConstant = ScalarConstant<element::Float>;
using Float32ScalarConstant = ScalarConstant<element::Float32>;
using Int8ScalarConstant = ScalarConstant<element::Int8>;
using Int32ScalarConstant = ScalarConstant<element::Int32>;
using Int64ScalarConstant = ScalarConstant<element::Int64>;
......
@@ -82,46 +82,4 @@ namespace ngraph
protected:
std::vector<std::shared_ptr<ValueType>> m_element_types;
};
/**
** Mixin for objects with type information
**/
class TypedValueMixin
{
public:
TypedValueMixin(const std::shared_ptr<ValueType>& value_type = nullptr)
: m_value_type(value_type)
{
}
/**
** Set the type
** /param type The new type
**/
void set_value_type(const std::shared_ptr<ValueType>& value_type)
{
m_value_type = value_type;
}
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void set_value_type(const element::Type& element_type, const Shape& shape)
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
/**
** The type associated with this value.
**/
std::shared_ptr<ValueType> get_value_type() { return m_value_type; }
/**
** The type associated with this value.
**/
const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; }
protected:
std::shared_ptr<ValueType> m_value_type;
};
}
@@ -20,7 +20,7 @@ using namespace std;
using namespace ngraph;
Parameter::Parameter(const std::shared_ptr<ValueType>& value_type)
: Node({}, value_type)
: Node(value_type)
, m_function(nullptr)
, m_index(0)
{
......
@@ -23,10 +23,10 @@ using namespace ngraph;
TEST(build_graph, build_simple)
{
// Function with 4 parameters
auto arg0 = node<Parameter>(element::Float::element_type(), Shape{7, 3});
auto arg1 = node<Parameter>(element::Float::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float::element_type(), Shape{32, 7});
auto arg3 = node<Parameter>(element::Float::element_type(), Shape{32, 7});
auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = node<Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = node<DotOp>(arg2, arg0);
@@ -35,14 +35,14 @@ TEST(build_graph, build_simple)
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->result(), dot);
ASSERT_EQ(cluster_0->get_result(), dot);
}
// Check upcasting from ValueType.
TEST(build_graph, as_type)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
auto tv_vt = make_shared<TensorViewType>(element::Float::element_type(), Shape{2, 3, 5});
auto tv_vt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
@@ -59,14 +59,14 @@ TEST(build_graph, as_type)
// Check node comparisons
TEST(build_graph, node_comparison)
{
auto arg0 = node<Parameter>(element::Float::element_type(), Shape{32, 3});
auto arg1 = node<Parameter>(element::Float::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float::element_type(), Shape{32});
auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{32, 3});
auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32});
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
auto parg = node<Parameter>(element::Float::element_type(), Shape{});
auto parg = node<Parameter>(element::Float32::element_type(), Shape{});
auto pattern_dot = node<DotOp>(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented.
@@ -78,8 +78,8 @@ TEST(build_graph, literal)
{
// float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0);
auto float0 = node<FloatScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{});
auto float0 = node<Float32ScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
auto d = node<DotOp>(float0, float0);
@@ -87,7 +87,7 @@ TEST(build_graph, literal)
ASSERT_EQ(d->get_arguments().at(1), float0);
// float scalar from an int
auto float1 = node<FloatScalarConstant>(3);
auto float1 = node<Float32ScalarConstant>(3);
ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
......
@@ -23,7 +23,7 @@ using namespace ngraph;
TEST(op, is_op)
{
auto arg0 = op::parameter(element::Float::element_type(), {1});
auto arg0 = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op());
@@ -31,7 +31,7 @@ TEST(op, is_op)
TEST(op, is_parameter)
{
auto arg0 = op::parameter(element::Float::element_type(), {1});
auto arg0 = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg0);
auto t0 = op::add(arg0, arg0);
ASSERT_NE(nullptr, t0);
......
@@ -61,7 +61,7 @@ TEST(topological_sort, basic)
vector<shared_ptr<Parameter>> args;
for (int i = 0; i < 10; i++)
{
auto arg = op::parameter(element::Float::element_type(), {1});
auto arg = op::parameter(element::Float32::element_type(), {1});
ASSERT_NE(nullptr, arg);
args.push_back(arg);
}
......