Commit f106a582 authored by Scott Cyphers's avatar Scott Cyphers

Rework element type to not depend on static initializers

parent c21b3f88
......@@ -39,8 +39,8 @@ namespace ngraph
return h(m_cname);
}
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
//bool operator==(const Type& other) const;
//bool operator!=(const Type& other) const { return !(*this == other); }
private:
static std::map<std::string, Type> m_element_list;
......@@ -53,31 +53,89 @@ namespace ngraph
// Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation.
template <typename T>
template <typename T, typename U>
class TraitedType : public Type
{
protected:
TraitedType(const std::string& cname)
: Type(sizeof(T) * 8,
std::is_floating_point<T>::value,
std::is_signed<T>::value,
cname)
{
}
public:
// This is the C++ type used to hold a value of this element type during compilation
using ctype = T;
// This is a reference to an instance of this element type.
static const TraitedType<T>& type;
TraitedType(const std::string& cname)
: Type(sizeof(T) * 8,
std::is_floating_point<T>::value,
std::is_signed<T>::value,
cname)
{
static const U& element_type(){
static U t;
return t;
}
};
// Human-readable names for the element types
using Float = TraitedType<float>;
using Int8 = TraitedType<int8_t>;
using Int32 = TraitedType<int32_t>;
using Int64 = TraitedType<int64_t>;
using UInt8 = TraitedType<uint8_t>;
using UInt32 = TraitedType<uint32_t>;
using UInt64 = TraitedType<uint64_t>;
class Float : public TraitedType<float, Float>
{
friend class TraitedType<float, Float>;
Float()
: TraitedType<float, Float>("float")
{
}
};
class Int8 : public TraitedType<int8_t, Int8>
{
friend class TraitedType<int8_t, Int8>;
Int8()
: TraitedType<int8_t, Int8>("int8_t")
{
}
};
class Int32 : public TraitedType<int32_t, Int32>
{
friend class TraitedType<int32_t, Int32>;
Int32()
: TraitedType<int32_t, Int32>("int32_t")
{
}
};
class Int64 : public TraitedType<int64_t, Int64>
{
friend class TraitedType<int64_t, Int64>;
Int64()
: TraitedType<int64_t, Int64>("int64_t")
{
}
};
class UInt8 : public TraitedType<uint8_t, UInt8>
{
friend class TraitedType<uint8_t, UInt8>;
UInt8()
: TraitedType<uint8_t, UInt8>("uint8_t")
{
}
};
class UInt32 : public TraitedType<uint32_t, UInt32>
{
friend class TraitedType<uint32_t, UInt32>;
UInt32()
: TraitedType<uint32_t, UInt32>("uint32_t")
{
}
};
class UInt64 : public TraitedType<uint64_t, UInt64>
{
friend class TraitedType<uint64_t, UInt64>;
UInt64()
: TraitedType<uint64_t, UInt64>("uint64_t")
{
}
};
}
}
......@@ -19,10 +19,10 @@
namespace ngraph
{
// Defines methods to all constant scalars
class ScalarConstantBaseOp : public Node
class ScalarConstantBase : public Node
{
protected:
ScalarConstantBaseOp(const std::shared_ptr<TensorViewType>& type)
ScalarConstantBase(const std::shared_ptr<TensorViewType>& type)
: Node({}, type)
{
}
......@@ -33,7 +33,7 @@ namespace ngraph
// Implement a constant scalar for each element type.
// The static make method takes a
template <typename T>
class ScalarConstantOp : public ScalarConstantBaseOp
class ScalarConstant : public ScalarConstantBase
{
public:
// The ngraph element type
......@@ -41,8 +41,8 @@ namespace ngraph
// The C++ type that holds the element type
using ctype = typename T::ctype;
ScalarConstantOp(typename T::ctype value)
: ScalarConstantBaseOp(std::make_shared<TensorViewType>(T::type, Shape{}))
ScalarConstant(typename T::ctype value)
: ScalarConstantBase(std::make_shared<TensorViewType>(T::element_type(), Shape{}))
, m_value(value)
{
}
......@@ -54,20 +54,20 @@ namespace ngraph
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
template <typename U>
static std::shared_ptr<ScalarConstantOp<T>> make(U value)
static std::shared_ptr<ScalarConstant<T>> make(U value)
{
return std::make_shared<ScalarConstantOp<T>>(value);
return std::make_shared<ScalarConstant<T>>(value);
}
protected:
typename T::ctype m_value;
};
using FloatScalarConstantOp = ScalarConstantOp<element::Float>;
using Int8ScalarConstantOp = ScalarConstantOp<element::Int8>;
using Int32ScalarConstantOp = ScalarConstantOp<element::Int32>;
using Int64ScalarConstantOp = ScalarConstantOp<element::Int64>;
using UInt8ScalarConstantOp = ScalarConstantOp<element::UInt8>;
using UInt32ScalarConstantOp = ScalarConstantOp<element::UInt32>;
using UInt64ScalarConstantOp = ScalarConstantOp<element::UInt64>;
using FloatScalarConstant = ScalarConstant<element::Float>;
using Int8ScalarConstant = ScalarConstant<element::Int8>;
using Int32ScalarConstant = ScalarConstant<element::Int32>;
using Int64ScalarConstant = ScalarConstant<element::Int64>;
using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
}
......@@ -16,4 +16,4 @@
using namespace ngraph;
void ScalarConstantBaseOp::propagate_types() {}
void ScalarConstantBase::propagate_types() {}
......@@ -48,28 +48,3 @@ size_t ngraph::element::Type::size() const
{
return std::ceil((float)m_bitwidth / 8.0);
}
namespace
{
const element::Float s_float32_t = element::Float{"float"};
const element::Int8 s_int8_t = element::Int8{"int8_t"};
const element::Int32 s_int32_t = element::Int32{"int32_t"};
const element::Int64 s_int64_t = element::Int64{"int64_t"};
const element::UInt8 s_uint8_t = element::UInt8{"uint8_t"};
const element::UInt32 s_uint32_t = element::UInt32{"uint32_t"};
const element::UInt64 s_uint64_t = element::UInt64{"uint64_t"};
}
template <>
const element::TraitedType<float>& element::TraitedType<float>::type = s_float32_t;
template <>
const element::TraitedType<int8_t>& element::TraitedType<int8_t>::type = s_int8_t;
template <>
const element::TraitedType<int32_t>& element::TraitedType<int32_t>::type = s_int32_t;
template <>
const element::TraitedType<int64_t>& element::TraitedType<int64_t>::type = s_int64_t;
template <>
const element::TraitedType<uint8_t>& element::TraitedType<uint8_t>::type = s_uint8_t;
template <>
const element::TraitedType<uint32_t>& element::TraitedType<uint32_t>::type = s_uint32_t;
template <>
const element::TraitedType<uint64_t>& element::TraitedType<uint64_t>::type = s_uint64_t;
\ No newline at end of file
......@@ -36,10 +36,10 @@ std::shared_ptr<Parameter> myfun<Parameter> (ngraph::element::Type&& element_typ
TEST(build_graph, build_simple)
{
// Function with 4 parameters
auto arg0 = myfun<Parameter>(element::Float::type, Shape{7, 3});
auto arg1 = op::parameter(element::Float::type, Shape{3});
auto arg2 = op::parameter(element::Float::type, Shape{32, 7});
auto arg3 = op::parameter(element::Float::type, Shape{32, 7});
auto arg0 = op::parameter(element::Float::element_type(), Shape{7, 3});
auto arg1 = op::parameter(element::Float::element_type(), Shape{3});
auto arg2 = op::parameter(element::Float::element_type(), Shape{32, 7});
auto arg3 = op::parameter(element::Float::element_type(), Shape{32, 7});
auto broadcast_1 = op::broadcast(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto b1 = myfun<BroadcastOp>(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto dot = op::dot(arg2, arg0);
......@@ -56,7 +56,7 @@ TEST(build_graph, build_simple)
TEST(build_graph, as_type)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::Float::type, Shape{2, 3, 5});
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::Float::element_type(), Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
......@@ -73,14 +73,14 @@ TEST(build_graph, as_type)
// Check node comparisons
TEST(build_graph, node_comparison)
{
auto arg0 = op::parameter(element::Float::type, {32, 3});
auto arg1 = op::parameter(element::Float::type, {3});
auto arg2 = op::parameter(element::Float::type, {32});
auto arg0 = op::parameter(element::Float::element_type(), {32, 3});
auto arg1 = op::parameter(element::Float::element_type(), {3});
auto arg2 = op::parameter(element::Float::element_type(), {32});
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
auto parg = op::parameter(element::Float::type, {});
auto parg = op::parameter(element::Float::element_type(), {});
auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented.
......@@ -91,8 +91,8 @@ TEST(build_graph, node_comparison)
TEST(build_graph, literal)
{
// float scalar from a float
auto float0 = FloatScalarConstantOp::make(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::type, Shape{});
auto float0 = FloatScalarConstant::make(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{});
ASSERT_EQ(float0->value(), 3.0);
ASSERT_EQ(*float0->type(), float_scalar_type);
auto d = op::dot(float0, float0);
......@@ -100,12 +100,12 @@ TEST(build_graph, literal)
ASSERT_EQ(d->arguments().at(1), float0);
// float scalar from an int
auto float1 = FloatScalarConstantOp::make(3);
auto float1 = FloatScalarConstant::make(3);
ASSERT_EQ(float1->value(), 3);
ASSERT_EQ(*float1->type(), float_scalar_type);
auto int32_0 = Int32ScalarConstantOp::make(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::type, Shape{});
auto int32_0 = Int32ScalarConstant::make(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->value(), 3);
ASSERT_EQ(*int32_0->type(), int32_scalar_type);
ASSERT_NE(*int32_0->type(), float_scalar_type);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment