Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
f106a582
Commit
f106a582
authored
Aug 31, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Rework element type to not depend on static initializers
parent
c21b3f88
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
106 additions
and
74 deletions
+106
-74
element_type.hpp
src/ngraph/element_type.hpp
+77
-19
constant.hpp
src/ngraph/ops/constant.hpp
+14
-14
constant.cpp
src/ops/constant.cpp
+1
-1
element_type.cpp
src/types/element_type.cpp
+0
-26
build_graph.cpp
test/build_graph.cpp
+14
-14
No files found.
src/ngraph/element_type.hpp
View file @
f106a582
...
...
@@ -39,8 +39,8 @@ namespace ngraph
return
h
(
m_cname
);
}
bool
operator
==
(
const
Type
&
other
)
const
;
bool
operator
!=
(
const
Type
&
other
)
const
{
return
!
(
*
this
==
other
);
}
//
bool operator==(const Type& other) const;
//
bool operator!=(const Type& other) const { return !(*this == other); }
private
:
static
std
::
map
<
std
::
string
,
Type
>
m_element_list
;
...
...
@@ -53,31 +53,89 @@ namespace ngraph
// Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation.
template
<
typename
T
>
template
<
typename
T
,
typename
U
>
class
TraitedType
:
public
Type
{
protected
:
TraitedType
(
const
std
::
string
&
cname
)
:
Type
(
sizeof
(
T
)
*
8
,
std
::
is_floating_point
<
T
>::
value
,
std
::
is_signed
<
T
>::
value
,
cname
)
{
}
public
:
// This is the C++ type used to hold a value of this element type during compilation
using
ctype
=
T
;
// This is a reference to an instance of this element type.
static
const
TraitedType
<
T
>&
type
;
TraitedType
(
const
std
::
string
&
cname
)
:
Type
(
sizeof
(
T
)
*
8
,
std
::
is_floating_point
<
T
>::
value
,
std
::
is_signed
<
T
>::
value
,
cname
)
{
static
const
U
&
element_type
(){
static
U
t
;
return
t
;
}
};
// Human-readable names for the element types
using
Float
=
TraitedType
<
float
>
;
using
Int8
=
TraitedType
<
int8_t
>
;
using
Int32
=
TraitedType
<
int32_t
>
;
using
Int64
=
TraitedType
<
int64_t
>
;
using
UInt8
=
TraitedType
<
uint8_t
>
;
using
UInt32
=
TraitedType
<
uint32_t
>
;
using
UInt64
=
TraitedType
<
uint64_t
>
;
class
Float
:
public
TraitedType
<
float
,
Float
>
{
friend
class
TraitedType
<
float
,
Float
>
;
Float
()
:
TraitedType
<
float
,
Float
>
(
"float"
)
{
}
};
class
Int8
:
public
TraitedType
<
int8_t
,
Int8
>
{
friend
class
TraitedType
<
int8_t
,
Int8
>
;
Int8
()
:
TraitedType
<
int8_t
,
Int8
>
(
"int8_t"
)
{
}
};
class
Int32
:
public
TraitedType
<
int32_t
,
Int32
>
{
friend
class
TraitedType
<
int32_t
,
Int32
>
;
Int32
()
:
TraitedType
<
int32_t
,
Int32
>
(
"int32_t"
)
{
}
};
class
Int64
:
public
TraitedType
<
int64_t
,
Int64
>
{
friend
class
TraitedType
<
int64_t
,
Int64
>
;
Int64
()
:
TraitedType
<
int64_t
,
Int64
>
(
"int64_t"
)
{
}
};
class
UInt8
:
public
TraitedType
<
uint8_t
,
UInt8
>
{
friend
class
TraitedType
<
uint8_t
,
UInt8
>
;
UInt8
()
:
TraitedType
<
uint8_t
,
UInt8
>
(
"uint8_t"
)
{
}
};
class
UInt32
:
public
TraitedType
<
uint32_t
,
UInt32
>
{
friend
class
TraitedType
<
uint32_t
,
UInt32
>
;
UInt32
()
:
TraitedType
<
uint32_t
,
UInt32
>
(
"uint32_t"
)
{
}
};
class
UInt64
:
public
TraitedType
<
uint64_t
,
UInt64
>
{
friend
class
TraitedType
<
uint64_t
,
UInt64
>
;
UInt64
()
:
TraitedType
<
uint64_t
,
UInt64
>
(
"uint64_t"
)
{
}
};
}
}
src/ngraph/ops/constant.hpp
View file @
f106a582
...
...
@@ -19,10 +19,10 @@
namespace
ngraph
{
// Defines methods to all constant scalars
class
ScalarConstantBase
Op
:
public
Node
class
ScalarConstantBase
:
public
Node
{
protected
:
ScalarConstantBase
Op
(
const
std
::
shared_ptr
<
TensorViewType
>&
type
)
ScalarConstantBase
(
const
std
::
shared_ptr
<
TensorViewType
>&
type
)
:
Node
({},
type
)
{
}
...
...
@@ -33,7 +33,7 @@ namespace ngraph
// Implement a constant scalar for each element type.
// The static make method takes a
template
<
typename
T
>
class
ScalarConstant
Op
:
public
ScalarConstantBaseOp
class
ScalarConstant
:
public
ScalarConstantBase
{
public
:
// The ngraph element type
...
...
@@ -41,8 +41,8 @@ namespace ngraph
// The C++ type that holds the element type
using
ctype
=
typename
T
::
ctype
;
ScalarConstant
Op
(
typename
T
::
ctype
value
)
:
ScalarConstantBase
Op
(
std
::
make_shared
<
TensorViewType
>
(
T
::
type
,
Shape
{}))
ScalarConstant
(
typename
T
::
ctype
value
)
:
ScalarConstantBase
(
std
::
make_shared
<
TensorViewType
>
(
T
::
element_type
()
,
Shape
{}))
,
m_value
(
value
)
{
}
...
...
@@ -54,20 +54,20 @@ namespace ngraph
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
template
<
typename
U
>
static
std
::
shared_ptr
<
ScalarConstant
Op
<
T
>>
make
(
U
value
)
static
std
::
shared_ptr
<
ScalarConstant
<
T
>>
make
(
U
value
)
{
return
std
::
make_shared
<
ScalarConstant
Op
<
T
>>
(
value
);
return
std
::
make_shared
<
ScalarConstant
<
T
>>
(
value
);
}
protected
:
typename
T
::
ctype
m_value
;
};
using
FloatScalarConstant
Op
=
ScalarConstantOp
<
element
::
Float
>
;
using
Int8ScalarConstant
Op
=
ScalarConstantOp
<
element
::
Int8
>
;
using
Int32ScalarConstant
Op
=
ScalarConstantOp
<
element
::
Int32
>
;
using
Int64ScalarConstant
Op
=
ScalarConstantOp
<
element
::
Int64
>
;
using
UInt8ScalarConstant
Op
=
ScalarConstantOp
<
element
::
UInt8
>
;
using
UInt32ScalarConstant
Op
=
ScalarConstantOp
<
element
::
UInt32
>
;
using
UInt64ScalarConstant
Op
=
ScalarConstantOp
<
element
::
UInt64
>
;
using
FloatScalarConstant
=
ScalarConstant
<
element
::
Float
>
;
using
Int8ScalarConstant
=
ScalarConstant
<
element
::
Int8
>
;
using
Int32ScalarConstant
=
ScalarConstant
<
element
::
Int32
>
;
using
Int64ScalarConstant
=
ScalarConstant
<
element
::
Int64
>
;
using
UInt8ScalarConstant
=
ScalarConstant
<
element
::
UInt8
>
;
using
UInt32ScalarConstant
=
ScalarConstant
<
element
::
UInt32
>
;
using
UInt64ScalarConstant
=
ScalarConstant
<
element
::
UInt64
>
;
}
src/ops/constant.cpp
View file @
f106a582
...
...
@@ -16,4 +16,4 @@
using
namespace
ngraph
;
void
ScalarConstantBase
Op
::
propagate_types
()
{}
void
ScalarConstantBase
::
propagate_types
()
{}
src/types/element_type.cpp
View file @
f106a582
...
...
@@ -48,28 +48,3 @@ size_t ngraph::element::Type::size() const
{
return
std
::
ceil
((
float
)
m_bitwidth
/
8.0
);
}
namespace
{
const
element
::
Float
s_float32_t
=
element
::
Float
{
"float"
};
const
element
::
Int8
s_int8_t
=
element
::
Int8
{
"int8_t"
};
const
element
::
Int32
s_int32_t
=
element
::
Int32
{
"int32_t"
};
const
element
::
Int64
s_int64_t
=
element
::
Int64
{
"int64_t"
};
const
element
::
UInt8
s_uint8_t
=
element
::
UInt8
{
"uint8_t"
};
const
element
::
UInt32
s_uint32_t
=
element
::
UInt32
{
"uint32_t"
};
const
element
::
UInt64
s_uint64_t
=
element
::
UInt64
{
"uint64_t"
};
}
template
<>
const
element
::
TraitedType
<
float
>&
element
::
TraitedType
<
float
>::
type
=
s_float32_t
;
template
<>
const
element
::
TraitedType
<
int8_t
>&
element
::
TraitedType
<
int8_t
>::
type
=
s_int8_t
;
template
<>
const
element
::
TraitedType
<
int32_t
>&
element
::
TraitedType
<
int32_t
>::
type
=
s_int32_t
;
template
<>
const
element
::
TraitedType
<
int64_t
>&
element
::
TraitedType
<
int64_t
>::
type
=
s_int64_t
;
template
<>
const
element
::
TraitedType
<
uint8_t
>&
element
::
TraitedType
<
uint8_t
>::
type
=
s_uint8_t
;
template
<>
const
element
::
TraitedType
<
uint32_t
>&
element
::
TraitedType
<
uint32_t
>::
type
=
s_uint32_t
;
template
<>
const
element
::
TraitedType
<
uint64_t
>&
element
::
TraitedType
<
uint64_t
>::
type
=
s_uint64_t
;
\ No newline at end of file
test/build_graph.cpp
View file @
f106a582
...
...
@@ -36,10 +36,10 @@ std::shared_ptr<Parameter> myfun<Parameter> (ngraph::element::Type&& element_typ
TEST
(
build_graph
,
build_simple
)
{
// Function with 4 parameters
auto
arg0
=
myfun
<
Parameter
>
(
element
::
Float
::
type
,
Shape
{
7
,
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
type
,
Shape
{
3
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
type
,
Shape
{
32
,
7
});
auto
arg3
=
op
::
parameter
(
element
::
Float
::
type
,
Shape
{
32
,
7
});
auto
arg0
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
Shape
{
7
,
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
Shape
{
3
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
Shape
{
32
,
7
});
auto
arg3
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
Shape
{
32
,
7
});
auto
broadcast_1
=
op
::
broadcast
(
arg3
,
Shape
{
10
,
32
,
7
},
BroadcastOp
::
Axes
{
0
});
auto
b1
=
myfun
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
BroadcastOp
::
Axes
{
0
});
auto
dot
=
op
::
dot
(
arg2
,
arg0
);
...
...
@@ -56,7 +56,7 @@ TEST(build_graph, build_simple)
TEST
(
build_graph
,
as_type
)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType
::
ptr
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
type
,
Shape
{
2
,
3
,
5
});
ValueType
::
ptr
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
()
,
Shape
{
2
,
3
,
5
});
auto
tv_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tv_vt
);
ASSERT_EQ
(
tv_vt
,
tv_tv
);
auto
tv_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tv_vt
);
...
...
@@ -73,14 +73,14 @@ TEST(build_graph, as_type)
// Check node comparisons
TEST
(
build_graph
,
node_comparison
)
{
auto
arg0
=
op
::
parameter
(
element
::
Float
::
type
,
{
32
,
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
type
,
{
3
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
type
,
{
32
});
auto
arg0
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
{
32
,
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
{
3
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
{
32
});
auto
dot
=
op
::
dot
(
arg0
,
arg1
);
auto
add
=
op
::
add
(
dot
,
arg2
);
auto
parg
=
op
::
parameter
(
element
::
Float
::
type
,
{});
auto
parg
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
{});
auto
pattern_dot
=
op
::
dot
(
parg
,
parg
);
ASSERT_TRUE
(
pattern_dot
->
is_same_op_type
(
dot
));
// TODO This passes because typeid is not behaving as documented.
...
...
@@ -91,8 +91,8 @@ TEST(build_graph, node_comparison)
TEST
(
build_graph
,
literal
)
{
// float scalar from a float
auto
float0
=
FloatScalarConstant
Op
::
make
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
type
,
Shape
{});
auto
float0
=
FloatScalarConstant
::
make
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
()
,
Shape
{});
ASSERT_EQ
(
float0
->
value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
type
(),
float_scalar_type
);
auto
d
=
op
::
dot
(
float0
,
float0
);
...
...
@@ -100,12 +100,12 @@ TEST(build_graph, literal)
ASSERT_EQ
(
d
->
arguments
().
at
(
1
),
float0
);
// float scalar from an int
auto
float1
=
FloatScalarConstant
Op
::
make
(
3
);
auto
float1
=
FloatScalarConstant
::
make
(
3
);
ASSERT_EQ
(
float1
->
value
(),
3
);
ASSERT_EQ
(
*
float1
->
type
(),
float_scalar_type
);
auto
int32_0
=
Int32ScalarConstant
Op
::
make
(
3.0
);
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
type
,
Shape
{});
auto
int32_0
=
Int32ScalarConstant
::
make
(
3.0
);
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
()
,
Shape
{});
ASSERT_EQ
(
int32_0
->
value
(),
3
);
ASSERT_EQ
(
*
int32_0
->
type
(),
int32_scalar_type
);
ASSERT_NE
(
*
int32_0
->
type
(),
float_scalar_type
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment