Commit f106a582 authored by Scott Cyphers's avatar Scott Cyphers

Rework element type to not depend on static initializers

parent c21b3f88
...@@ -39,8 +39,8 @@ namespace ngraph ...@@ -39,8 +39,8 @@ namespace ngraph
return h(m_cname); return h(m_cname);
} }
bool operator==(const Type& other) const; //bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); } //bool operator!=(const Type& other) const { return !(*this == other); }
private: private:
static std::map<std::string, Type> m_element_list; static std::map<std::string, Type> m_element_list;
...@@ -53,15 +53,10 @@ namespace ngraph ...@@ -53,15 +53,10 @@ namespace ngraph
// Literals (and probably other things we don't know about yet) need to have their C++ types // Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides // and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation. // access to both the instance and the C++ type used to hold the value during compilation.
template <typename T> template <typename T, typename U>
class TraitedType : public Type class TraitedType : public Type
{ {
public: protected:
// This is the C++ type used to hold a value of this element type during compilation
using ctype = T;
// This is a reference to an instance of this element type.
static const TraitedType<T>& type;
TraitedType(const std::string& cname) TraitedType(const std::string& cname)
: Type(sizeof(T) * 8, : Type(sizeof(T) * 8,
std::is_floating_point<T>::value, std::is_floating_point<T>::value,
...@@ -69,15 +64,78 @@ namespace ngraph ...@@ -69,15 +64,78 @@ namespace ngraph
cname) cname)
{ {
} }
public:
// This is the C++ type used to hold a value of this element type during compilation
using ctype = T;
// This is a reference to an instance of this element type.
static const U& element_type(){
static U t;
return t;
}
}; };
// Human-readable names for the element types class Float : public TraitedType<float, Float>
using Float = TraitedType<float>; {
using Int8 = TraitedType<int8_t>; friend class TraitedType<float, Float>;
using Int32 = TraitedType<int32_t>; Float()
using Int64 = TraitedType<int64_t>; : TraitedType<float, Float>("float")
using UInt8 = TraitedType<uint8_t>; {
using UInt32 = TraitedType<uint32_t>; }
using UInt64 = TraitedType<uint64_t>; };
class Int8 : public TraitedType<int8_t, Int8>
{
friend class TraitedType<int8_t, Int8>;
Int8()
: TraitedType<int8_t, Int8>("int8_t")
{
}
};
class Int32 : public TraitedType<int32_t, Int32>
{
friend class TraitedType<int32_t, Int32>;
Int32()
: TraitedType<int32_t, Int32>("int32_t")
{
}
};
class Int64 : public TraitedType<int64_t, Int64>
{
friend class TraitedType<int64_t, Int64>;
Int64()
: TraitedType<int64_t, Int64>("int64_t")
{
}
};
class UInt8 : public TraitedType<uint8_t, UInt8>
{
friend class TraitedType<uint8_t, UInt8>;
UInt8()
: TraitedType<uint8_t, UInt8>("uint8_t")
{
}
};
class UInt32 : public TraitedType<uint32_t, UInt32>
{
friend class TraitedType<uint32_t, UInt32>;
UInt32()
: TraitedType<uint32_t, UInt32>("uint32_t")
{
}
};
class UInt64 : public TraitedType<uint64_t, UInt64>
{
friend class TraitedType<uint64_t, UInt64>;
UInt64()
: TraitedType<uint64_t, UInt64>("uint64_t")
{
}
};
} }
} }
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
namespace ngraph namespace ngraph
{ {
// Defines methods to all constant scalars // Defines methods to all constant scalars
class ScalarConstantBaseOp : public Node class ScalarConstantBase : public Node
{ {
protected: protected:
ScalarConstantBaseOp(const std::shared_ptr<TensorViewType>& type) ScalarConstantBase(const std::shared_ptr<TensorViewType>& type)
: Node({}, type) : Node({}, type)
{ {
} }
...@@ -33,7 +33,7 @@ namespace ngraph ...@@ -33,7 +33,7 @@ namespace ngraph
// Implement a constant scalar for each element type. // Implement a constant scalar for each element type.
// The static make method takes a // The static make method takes a
template <typename T> template <typename T>
class ScalarConstantOp : public ScalarConstantBaseOp class ScalarConstant : public ScalarConstantBase
{ {
public: public:
// The ngraph element type // The ngraph element type
...@@ -41,8 +41,8 @@ namespace ngraph ...@@ -41,8 +41,8 @@ namespace ngraph
// The C++ type that holds the element type // The C++ type that holds the element type
using ctype = typename T::ctype; using ctype = typename T::ctype;
ScalarConstantOp(typename T::ctype value) ScalarConstant(typename T::ctype value)
: ScalarConstantBaseOp(std::make_shared<TensorViewType>(T::type, Shape{})) : ScalarConstantBase(std::make_shared<TensorViewType>(T::element_type(), Shape{}))
, m_value(value) , m_value(value)
{ {
} }
...@@ -54,20 +54,20 @@ namespace ngraph ...@@ -54,20 +54,20 @@ namespace ngraph
// Make a constant from any value that can be converted to the C++ type we use // Make a constant from any value that can be converted to the C++ type we use
// to represent the values. // to represent the values.
template <typename U> template <typename U>
static std::shared_ptr<ScalarConstantOp<T>> make(U value) static std::shared_ptr<ScalarConstant<T>> make(U value)
{ {
return std::make_shared<ScalarConstantOp<T>>(value); return std::make_shared<ScalarConstant<T>>(value);
} }
protected: protected:
typename T::ctype m_value; typename T::ctype m_value;
}; };
using FloatScalarConstantOp = ScalarConstantOp<element::Float>; using FloatScalarConstant = ScalarConstant<element::Float>;
using Int8ScalarConstantOp = ScalarConstantOp<element::Int8>; using Int8ScalarConstant = ScalarConstant<element::Int8>;
using Int32ScalarConstantOp = ScalarConstantOp<element::Int32>; using Int32ScalarConstant = ScalarConstant<element::Int32>;
using Int64ScalarConstantOp = ScalarConstantOp<element::Int64>; using Int64ScalarConstant = ScalarConstant<element::Int64>;
using UInt8ScalarConstantOp = ScalarConstantOp<element::UInt8>; using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstantOp = ScalarConstantOp<element::UInt32>; using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstantOp = ScalarConstantOp<element::UInt64>; using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
} }
...@@ -16,4 +16,4 @@ ...@@ -16,4 +16,4 @@
using namespace ngraph; using namespace ngraph;
void ScalarConstantBaseOp::propagate_types() {} void ScalarConstantBase::propagate_types() {}
...@@ -48,28 +48,3 @@ size_t ngraph::element::Type::size() const ...@@ -48,28 +48,3 @@ size_t ngraph::element::Type::size() const
{ {
return std::ceil((float)m_bitwidth / 8.0); return std::ceil((float)m_bitwidth / 8.0);
} }
namespace
{
const element::Float s_float32_t = element::Float{"float"};
const element::Int8 s_int8_t = element::Int8{"int8_t"};
const element::Int32 s_int32_t = element::Int32{"int32_t"};
const element::Int64 s_int64_t = element::Int64{"int64_t"};
const element::UInt8 s_uint8_t = element::UInt8{"uint8_t"};
const element::UInt32 s_uint32_t = element::UInt32{"uint32_t"};
const element::UInt64 s_uint64_t = element::UInt64{"uint64_t"};
}
template <>
const element::TraitedType<float>& element::TraitedType<float>::type = s_float32_t;
template <>
const element::TraitedType<int8_t>& element::TraitedType<int8_t>::type = s_int8_t;
template <>
const element::TraitedType<int32_t>& element::TraitedType<int32_t>::type = s_int32_t;
template <>
const element::TraitedType<int64_t>& element::TraitedType<int64_t>::type = s_int64_t;
template <>
const element::TraitedType<uint8_t>& element::TraitedType<uint8_t>::type = s_uint8_t;
template <>
const element::TraitedType<uint32_t>& element::TraitedType<uint32_t>::type = s_uint32_t;
template <>
const element::TraitedType<uint64_t>& element::TraitedType<uint64_t>::type = s_uint64_t;
\ No newline at end of file
...@@ -36,10 +36,10 @@ std::shared_ptr<Parameter> myfun<Parameter> (ngraph::element::Type&& element_typ ...@@ -36,10 +36,10 @@ std::shared_ptr<Parameter> myfun<Parameter> (ngraph::element::Type&& element_typ
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto arg0 = myfun<Parameter>(element::Float::type, Shape{7, 3}); auto arg0 = op::parameter(element::Float::element_type(), Shape{7, 3});
auto arg1 = op::parameter(element::Float::type, Shape{3}); auto arg1 = op::parameter(element::Float::element_type(), Shape{3});
auto arg2 = op::parameter(element::Float::type, Shape{32, 7}); auto arg2 = op::parameter(element::Float::element_type(), Shape{32, 7});
auto arg3 = op::parameter(element::Float::type, Shape{32, 7}); auto arg3 = op::parameter(element::Float::element_type(), Shape{32, 7});
auto broadcast_1 = op::broadcast(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0}); auto broadcast_1 = op::broadcast(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto b1 = myfun<BroadcastOp>(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0}); auto b1 = myfun<BroadcastOp>(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto dot = op::dot(arg2, arg0); auto dot = op::dot(arg2, arg0);
...@@ -56,7 +56,7 @@ TEST(build_graph, build_simple) ...@@ -56,7 +56,7 @@ TEST(build_graph, build_simple)
TEST(build_graph, as_type) TEST(build_graph, as_type)
{ {
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple. // Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::Float::type, Shape{2, 3, 5}); ValueType::ptr tv_vt = make_shared<TensorViewType>(element::Float::element_type(), Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt); auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv); ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt); auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
...@@ -73,14 +73,14 @@ TEST(build_graph, as_type) ...@@ -73,14 +73,14 @@ TEST(build_graph, as_type)
// Check node comparisons // Check node comparisons
TEST(build_graph, node_comparison) TEST(build_graph, node_comparison)
{ {
auto arg0 = op::parameter(element::Float::type, {32, 3}); auto arg0 = op::parameter(element::Float::element_type(), {32, 3});
auto arg1 = op::parameter(element::Float::type, {3}); auto arg1 = op::parameter(element::Float::element_type(), {3});
auto arg2 = op::parameter(element::Float::type, {32}); auto arg2 = op::parameter(element::Float::element_type(), {32});
auto dot = op::dot(arg0, arg1); auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2); auto add = op::add(dot, arg2);
auto parg = op::parameter(element::Float::type, {}); auto parg = op::parameter(element::Float::element_type(), {});
auto pattern_dot = op::dot(parg, parg); auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot)); ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented. // TODO This passes because typeid is not behaving as documented.
...@@ -91,8 +91,8 @@ TEST(build_graph, node_comparison) ...@@ -91,8 +91,8 @@ TEST(build_graph, node_comparison)
TEST(build_graph, literal) TEST(build_graph, literal)
{ {
// float scalar from a float // float scalar from a float
auto float0 = FloatScalarConstantOp::make(3.0); auto float0 = FloatScalarConstant::make(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::type, Shape{}); auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{});
ASSERT_EQ(float0->value(), 3.0); ASSERT_EQ(float0->value(), 3.0);
ASSERT_EQ(*float0->type(), float_scalar_type); ASSERT_EQ(*float0->type(), float_scalar_type);
auto d = op::dot(float0, float0); auto d = op::dot(float0, float0);
...@@ -100,12 +100,12 @@ TEST(build_graph, literal) ...@@ -100,12 +100,12 @@ TEST(build_graph, literal)
ASSERT_EQ(d->arguments().at(1), float0); ASSERT_EQ(d->arguments().at(1), float0);
// float scalar from an int // float scalar from an int
auto float1 = FloatScalarConstantOp::make(3); auto float1 = FloatScalarConstant::make(3);
ASSERT_EQ(float1->value(), 3); ASSERT_EQ(float1->value(), 3);
ASSERT_EQ(*float1->type(), float_scalar_type); ASSERT_EQ(*float1->type(), float_scalar_type);
auto int32_0 = Int32ScalarConstantOp::make(3.0); auto int32_0 = Int32ScalarConstant::make(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::type, Shape{}); auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->value(), 3); ASSERT_EQ(int32_0->value(), 3);
ASSERT_EQ(*int32_0->type(), int32_scalar_type); ASSERT_EQ(*int32_0->type(), int32_scalar_type);
ASSERT_NE(*int32_0->type(), float_scalar_type); ASSERT_NE(*int32_0->type(), float_scalar_type);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment