Commit ede578a6 authored by Bob Kimball's avatar Bob Kimball

wip

parent 494b16cd
...@@ -23,36 +23,37 @@ ...@@ -23,36 +23,37 @@
namespace ngraph namespace ngraph
{ {
class ElementType; namespace element
}
class ngraph::ElementType
{
public:
ElementType(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname);
const std::string& c_type_string() const;
size_t size() const;
size_t hash() const
{ {
std::hash<std::string> h; class Type
return h(m_cname); {
public:
Type(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname);
const std::string& c_type_string() const;
size_t size() const;
size_t hash() const
{
std::hash<std::string> h;
return h(m_cname);
}
bool operator==(const Type& other) const;
private:
static std::map<std::string, Type> m_element_list;
size_t m_bitwidth;
bool m_is_float;
bool m_is_signed;
const std::string m_cname;
};
const Type float32_t= Type(32, true, true, "float");
const Type int8_t = Type(8, false, true, "int8_t");
const Type int32_t = Type(32, false, true, "int32_t");
const Type int64_t = Type(64, false, true, "int64_t");
const Type uint8_t = Type(8, false, false, "int8_t");
const Type uint32_t = Type(32, false, false, "int32_t");
const Type uint64_t = Type(64, false, false, "int64_t");
} }
}
bool operator==(const ElementType& other) const;
private:
static std::map<std::string, ElementType> m_element_list;
size_t m_bitwidth;
bool m_is_float;
bool m_is_signed;
const std::string m_cname;
};
extern const ngraph::ElementType element_type_float;
extern const ngraph::ElementType element_type_int8_t;
extern const ngraph::ElementType element_type_int32_t;
extern const ngraph::ElementType element_type_int64_t;
extern const ngraph::ElementType element_type_uint8_t;
extern const ngraph::ElementType element_type_uint32_t;
extern const ngraph::ElementType element_type_uint64_t;
...@@ -51,14 +51,14 @@ namespace ngraph ...@@ -51,14 +51,14 @@ namespace ngraph
** /param element_type The type of the tensor elements. ** /param element_type The type of the tensor elements.
** /param shape The shape of the tensor. ** /param shape The shape of the tensor.
**/ **/
TensorViewType(const ElementType& element_type, const Shape& shape) TensorViewType(const element::Type& element_type, const Shape& shape)
: m_element_type(element_type) : m_element_type(element_type)
, m_shape(shape) , m_shape(shape)
{ {
} }
protected: protected:
const ElementType& m_element_type; const element::Type& m_element_type;
Shape m_shape; Shape m_shape;
}; };
...@@ -115,7 +115,7 @@ namespace ngraph ...@@ -115,7 +115,7 @@ namespace ngraph
** /param element_type The type of the tensor elements ** /param element_type The type of the tensor elements
** /param shape The shape of the view ** /param shape The shape of the view
**/ **/
void type(const ElementType& element_type, const Shape& shape) void type(const element::Type& element_type, const Shape& shape)
{ {
m_type = std::make_shared<TensorViewType>(element_type, shape); m_type = std::make_shared<TensorViewType>(element_type, shape);
} }
......
...@@ -17,17 +17,9 @@ ...@@ -17,17 +17,9 @@
#include "ngraph/element_type.hpp" #include "ngraph/element_type.hpp"
const ngraph::ElementType element_type_float = ngraph::ElementType(32, true, true, "float"); std::map<std::string, ngraph::element::Type> ngraph::element::Type::m_element_list;
const ngraph::ElementType element_type_int8_t = ngraph::ElementType(8, false, true, "int8_t");
const ngraph::ElementType element_type_int32_t = ngraph::ElementType(32, false, true, "int32_t");
const ngraph::ElementType element_type_int64_t = ngraph::ElementType(64, false, true, "int64_t");
const ngraph::ElementType element_type_uint8_t = ngraph::ElementType(8, false, false, "int8_t");
const ngraph::ElementType element_type_uint32_t = ngraph::ElementType(32, false, false, "int32_t");
const ngraph::ElementType element_type_uint64_t = ngraph::ElementType(64, false, false, "int64_t");
std::map<std::string, ngraph::ElementType> ngraph::ElementType::m_element_list; ngraph::element::Type::Type(size_t bitwidth,
ngraph::ElementType::ElementType(size_t bitwidth,
bool is_float, bool is_float,
bool is_signed, bool is_signed,
const std::string& cname) const std::string& cname)
...@@ -39,18 +31,18 @@ ngraph::ElementType::ElementType(size_t bitwidth, ...@@ -39,18 +31,18 @@ ngraph::ElementType::ElementType(size_t bitwidth,
assert(m_bitwidth % 8 == 0); assert(m_bitwidth % 8 == 0);
} }
const std::string& ngraph::ElementType::c_type_string() const const std::string& ngraph::element::Type::c_type_string() const
{ {
return m_cname; return m_cname;
} }
bool ngraph::ElementType::operator==(const ElementType& other) const bool ngraph::element::Type::operator==(const element::Type& other) const
{ {
return m_bitwidth == other.m_bitwidth && m_is_float == other.m_is_float && return m_bitwidth == other.m_bitwidth && m_is_float == other.m_is_float &&
m_is_signed == other.m_is_signed; m_is_signed == other.m_is_signed;
} }
size_t ngraph::ElementType::size() const size_t ngraph::element::Type::size() const
{ {
return std::ceil((float)m_bitwidth / 8.0); return std::ceil((float)m_bitwidth / 8.0);
} }
...@@ -23,11 +23,11 @@ TEST(graph, build_simple) ...@@ -23,11 +23,11 @@ TEST(graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto cluster_0 = make_shared<Function>(4); auto cluster_0 = make_shared<Function>(4);
cluster_0->result()->type(element_type_float, {32, 3}); cluster_0->result()->type(element::float32_t, {32, 3});
cluster_0->parameter(0)->type(element_type_float, {7, 3}); cluster_0->parameter(0)->type(element::float32_t, {7, 3});
cluster_0->parameter(1)->type(element_type_float, {3}); cluster_0->parameter(1)->type(element::float32_t, {3});
cluster_0->parameter(2)->type(element_type_float, {32, 7}); cluster_0->parameter(2)->type(element::float32_t, {32, 7});
cluster_0->parameter(3)->type(element_type_float, {32, 7}); cluster_0->parameter(3)->type(element::float32_t, {32, 7});
auto arg3 = cluster_0->parameter(3); auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 1. // call broadcast op on arg3, broadcasting on axis 1.
auto broadcast_1 = op::broadcast(arg3, 1); auto broadcast_1 = op::broadcast(arg3, 1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment