Commit ede578a6 authored by Bob Kimball's avatar Bob Kimball

wip

parent 494b16cd
......@@ -23,13 +23,12 @@
namespace ngraph
{
class ElementType;
}
class ngraph::ElementType
{
public:
ElementType(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname);
namespace element
{
class Type
{
public:
Type(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname);
const std::string& c_type_string() const;
size_t size() const;
......@@ -39,20 +38,22 @@ public:
return h(m_cname);
}
bool operator==(const ElementType& other) const;
bool operator==(const Type& other) const;
private:
static std::map<std::string, ElementType> m_element_list;
private:
static std::map<std::string, Type> m_element_list;
size_t m_bitwidth;
bool m_is_float;
bool m_is_signed;
const std::string m_cname;
};
extern const ngraph::ElementType element_type_float;
extern const ngraph::ElementType element_type_int8_t;
extern const ngraph::ElementType element_type_int32_t;
extern const ngraph::ElementType element_type_int64_t;
extern const ngraph::ElementType element_type_uint8_t;
extern const ngraph::ElementType element_type_uint32_t;
extern const ngraph::ElementType element_type_uint64_t;
};
const Type float32_t= Type(32, true, true, "float");
const Type int8_t = Type(8, false, true, "int8_t");
const Type int32_t = Type(32, false, true, "int32_t");
const Type int64_t = Type(64, false, true, "int64_t");
const Type uint8_t = Type(8, false, false, "int8_t");
const Type uint32_t = Type(32, false, false, "int32_t");
const Type uint64_t = Type(64, false, false, "int64_t");
}
}
......@@ -51,14 +51,14 @@ namespace ngraph
** /param element_type The type of the tensor elements.
** /param shape The shape of the tensor.
**/
TensorViewType(const ElementType& element_type, const Shape& shape)
TensorViewType(const element::Type& element_type, const Shape& shape)
: m_element_type(element_type)
, m_shape(shape)
{
}
protected:
const ElementType& m_element_type;
const element::Type& m_element_type;
Shape m_shape;
};
......@@ -115,7 +115,7 @@ namespace ngraph
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void type(const ElementType& element_type, const Shape& shape)
void type(const element::Type& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape);
}
......
......@@ -17,17 +17,9 @@
#include "ngraph/element_type.hpp"
const ngraph::ElementType element_type_float = ngraph::ElementType(32, true, true, "float");
const ngraph::ElementType element_type_int8_t = ngraph::ElementType(8, false, true, "int8_t");
const ngraph::ElementType element_type_int32_t = ngraph::ElementType(32, false, true, "int32_t");
const ngraph::ElementType element_type_int64_t = ngraph::ElementType(64, false, true, "int64_t");
const ngraph::ElementType element_type_uint8_t = ngraph::ElementType(8, false, false, "int8_t");
const ngraph::ElementType element_type_uint32_t = ngraph::ElementType(32, false, false, "int32_t");
const ngraph::ElementType element_type_uint64_t = ngraph::ElementType(64, false, false, "int64_t");
std::map<std::string, ngraph::element::Type> ngraph::element::Type::m_element_list;
std::map<std::string, ngraph::ElementType> ngraph::ElementType::m_element_list;
ngraph::ElementType::ElementType(size_t bitwidth,
ngraph::element::Type::Type(size_t bitwidth,
bool is_float,
bool is_signed,
const std::string& cname)
......@@ -39,18 +31,18 @@ ngraph::ElementType::ElementType(size_t bitwidth,
assert(m_bitwidth % 8 == 0);
}
const std::string& ngraph::ElementType::c_type_string() const
const std::string& ngraph::element::Type::c_type_string() const
{
return m_cname;
}
bool ngraph::ElementType::operator==(const ElementType& other) const
bool ngraph::element::Type::operator==(const element::Type& other) const
{
return m_bitwidth == other.m_bitwidth && m_is_float == other.m_is_float &&
m_is_signed == other.m_is_signed;
}
size_t ngraph::ElementType::size() const
size_t ngraph::element::Type::size() const
{
return std::ceil((float)m_bitwidth / 8.0);
}
......@@ -23,11 +23,11 @@ TEST(graph, build_simple)
{
// Function with 4 parameters
auto cluster_0 = make_shared<Function>(4);
cluster_0->result()->type(element_type_float, {32, 3});
cluster_0->parameter(0)->type(element_type_float, {7, 3});
cluster_0->parameter(1)->type(element_type_float, {3});
cluster_0->parameter(2)->type(element_type_float, {32, 7});
cluster_0->parameter(3)->type(element_type_float, {32, 7});
cluster_0->result()->type(element::float32_t, {32, 3});
cluster_0->parameter(0)->type(element::float32_t, {7, 3});
cluster_0->parameter(1)->type(element::float32_t, {3});
cluster_0->parameter(2)->type(element::float32_t, {32, 7});
cluster_0->parameter(3)->type(element::float32_t, {32, 7});
auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 1.
auto broadcast_1 = op::broadcast(arg3, 1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment