Commit 3c6ab287 authored by Scott Cyphers's avatar Scott Cyphers

Review fixes

parent 6d6e923b
...@@ -38,7 +38,7 @@ namespace ngraph ...@@ -38,7 +38,7 @@ namespace ngraph
using Nodes = std::vector<std::shared_ptr<Node>>; using Nodes = std::vector<std::shared_ptr<Node>>;
/// A sequence of axes /// A sequence of axes
using AxisList = std::vector<size_t>; using AxisVector = std::vector<size_t>;
/// A set of axes, for example, reduction axes /// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>; using AxisSet = std::set<size_t>;
......
...@@ -52,13 +52,18 @@ namespace ngraph ...@@ -52,13 +52,18 @@ namespace ngraph
}; };
// Provides a compile-time name for a C++ type. // Provides a compile-time name for a C++ type.
// Used in TraitedType for the string that supplies the C++ type name. // Used in TraitedType for the string that supplies the C++ type name during code generation,
// so it needs to be a valid C++ name.
template<typename T> template<typename T>
const char* traited_type_name() const char* traited_type_name()
{ {
throw ngraph_error("Unkmown type"); throw ngraph_error("Unknown type");
} }
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TTN( T ) \
template<> constexpr const char* traited_type_name < T > () { return #T; }
// Literals (and probably other things we don't know about yet) need to have their C++ types // Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides // and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation. // access to both the instance and the C++ type used to hold the value during compilation.
...@@ -84,61 +89,25 @@ namespace ngraph ...@@ -84,61 +89,25 @@ namespace ngraph
} }
}; };
template<> NGRAPH_DEFINE_TTN( float )
constexpr const char* traited_type_name<float>()
{
return "float";
}
using Float = TraitedType<float>; using Float = TraitedType<float>;
NGRAPH_DEFINE_TTN( int8_t )
template<>
constexpr const char* traited_type_name<int8_t>()
{
return "int8_t";
}
using Int8 = TraitedType<int8_t>; using Int8 = TraitedType<int8_t>;
template<> NGRAPH_DEFINE_TTN( int32_t )
constexpr const char* traited_type_name<int32_t>()
{
return "int32_t";
}
using Int32 = TraitedType<int32_t>; using Int32 = TraitedType<int32_t>;
template<> NGRAPH_DEFINE_TTN( int64_t )
constexpr const char* traited_type_name<int64_t>()
{
return "int64_t";
}
using Int64 = TraitedType<int64_t>; using Int64 = TraitedType<int64_t>;
template<> NGRAPH_DEFINE_TTN( uint8_t )
constexpr const char* traited_type_name<uint8_t>()
{
return "uint8_t";
}
using UInt8 = TraitedType<uint8_t>; using UInt8 = TraitedType<uint8_t>;
template<> NGRAPH_DEFINE_TTN( uint32_t )
constexpr const char* traited_type_name<uint32_t>()
{
return "uint32_t";
}
using UInt32 = TraitedType<uint32_t>; using UInt32 = TraitedType<uint32_t>;
template<> NGRAPH_DEFINE_TTN( uint64_t )
constexpr const char* traited_type_name<uint64_t>()
{
return "uint64_t";
}
using UInt64 = TraitedType<uint64_t>; using UInt64 = TraitedType<uint64_t>;
} }
} }
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast. ** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg. ** the remaining axes in shape must be the same as the shape of arg.
**/ **/
BroadcastOp(const Node::ptr& arg, const Shape& shape, AxisSet& broadcast_axes) BroadcastOp(const Node::ptr& arg, const Shape& shape, const AxisSet& broadcast_axes)
: BuiltinOp({arg}) : BuiltinOp({arg})
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
......
...@@ -54,7 +54,7 @@ namespace ngraph ...@@ -54,7 +54,7 @@ namespace ngraph
/// Factory for frameworks /// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type = nullptr); std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type = nullptr);
/// Convenience factory for tests /// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const ngraph::element::Type element_type, std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape); const Shape& shape);
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment