Commit 3c6ab287 authored by Scott Cyphers's avatar Scott Cyphers

Review fixes

parent 6d6e923b
......@@ -38,7 +38,7 @@ namespace ngraph
using Nodes = std::vector<std::shared_ptr<Node>>;
/// A sequence of axes
using AxisList = std::vector<size_t>;
using AxisVector = std::vector<size_t>;
/// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>;
......
......@@ -52,13 +52,18 @@ namespace ngraph
};
// Provides a compile-time name for a C++ type.
// Used in TraitedType for the string that supplies the C++ type name.
// Used in TraitedType for the string that supplies the C++ type name during code generation,
// so it needs to be a valid C++ name.
template<typename T>
const char* traited_type_name()
{
throw ngraph_error("Unkmown type");
throw ngraph_error("Unknown type");
}
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TTN( T ) \
template<> constexpr const char* traited_type_name < T > () { return #T; }
// Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation.
......@@ -84,61 +89,25 @@ namespace ngraph
}
};
template<>
constexpr const char* traited_type_name<float>()
{
return "float";
}
NGRAPH_DEFINE_TTN( float )
using Float = TraitedType<float>;
template<>
constexpr const char* traited_type_name<int8_t>()
{
return "int8_t";
}
NGRAPH_DEFINE_TTN( int8_t )
using Int8 = TraitedType<int8_t>;
template<>
constexpr const char* traited_type_name<int32_t>()
{
return "int32_t";
}
NGRAPH_DEFINE_TTN( int32_t )
using Int32 = TraitedType<int32_t>;
template<>
constexpr const char* traited_type_name<int64_t>()
{
return "int64_t";
}
NGRAPH_DEFINE_TTN( int64_t )
using Int64 = TraitedType<int64_t>;
template<>
constexpr const char* traited_type_name<uint8_t>()
{
return "uint8_t";
}
NGRAPH_DEFINE_TTN( uint8_t )
using UInt8 = TraitedType<uint8_t>;
template<>
constexpr const char* traited_type_name<uint32_t>()
{
return "uint32_t";
}
NGRAPH_DEFINE_TTN( uint32_t )
using UInt32 = TraitedType<uint32_t>;
template<>
constexpr const char* traited_type_name<uint64_t>()
{
return "uint64_t";
}
NGRAPH_DEFINE_TTN( uint64_t )
using UInt64 = TraitedType<uint64_t>;
}
}
......@@ -25,7 +25,7 @@ namespace ngraph
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastOp(const Node::ptr& arg, const Shape& shape, AxisSet& broadcast_axes)
BroadcastOp(const Node::ptr& arg, const Shape& shape, const AxisSet& broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
......
......@@ -54,7 +54,7 @@ namespace ngraph
/// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type = nullptr);
/// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const ngraph::element::Type element_type,
std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment