Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
f106a582
Commit
f106a582
authored
Aug 31, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Rework element type to not depend on static initializers
parent
c21b3f88
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
104 additions
and
72 deletions
+104
-72
element_type.hpp
src/ngraph/element_type.hpp
+75
-17
constant.hpp
src/ngraph/ops/constant.hpp
+14
-14
constant.cpp
src/ops/constant.cpp
+1
-1
element_type.cpp
src/types/element_type.cpp
+0
-26
build_graph.cpp
test/build_graph.cpp
+14
-14
No files found.
src/ngraph/element_type.hpp
View file @
f106a582
...
@@ -39,8 +39,8 @@ namespace ngraph
...
@@ -39,8 +39,8 @@ namespace ngraph
return
h
(
m_cname
);
return
h
(
m_cname
);
}
}
bool
operator
==
(
const
Type
&
other
)
const
;
//
bool operator==(const Type& other) const;
bool
operator
!=
(
const
Type
&
other
)
const
{
return
!
(
*
this
==
other
);
}
//
bool operator!=(const Type& other) const { return !(*this == other); }
private
:
private
:
static
std
::
map
<
std
::
string
,
Type
>
m_element_list
;
static
std
::
map
<
std
::
string
,
Type
>
m_element_list
;
...
@@ -53,15 +53,10 @@ namespace ngraph
...
@@ -53,15 +53,10 @@ namespace ngraph
// Literals (and probably other things we don't know about yet) need to have their C++ types
// Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides
// and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation.
// access to both the instance and the C++ type used to hold the value during compilation.
template
<
typename
T
>
template
<
typename
T
,
typename
U
>
class
TraitedType
:
public
Type
class
TraitedType
:
public
Type
{
{
public
:
protected
:
// This is the C++ type used to hold a value of this element type during compilation
using
ctype
=
T
;
// This is a reference to an instance of this element type.
static
const
TraitedType
<
T
>&
type
;
TraitedType
(
const
std
::
string
&
cname
)
TraitedType
(
const
std
::
string
&
cname
)
:
Type
(
sizeof
(
T
)
*
8
,
:
Type
(
sizeof
(
T
)
*
8
,
std
::
is_floating_point
<
T
>::
value
,
std
::
is_floating_point
<
T
>::
value
,
...
@@ -69,15 +64,78 @@ namespace ngraph
...
@@ -69,15 +64,78 @@ namespace ngraph
cname
)
cname
)
{
{
}
}
public
:
// This is the C++ type used to hold a value of this element type during compilation
using
ctype
=
T
;
// This is a reference to an instance of this element type.
static
const
U
&
element_type
(){
static
U
t
;
return
t
;
}
};
};
// Human-readable names for the element types
class
Float
:
public
TraitedType
<
float
,
Float
>
using
Float
=
TraitedType
<
float
>
;
{
using
Int8
=
TraitedType
<
int8_t
>
;
friend
class
TraitedType
<
float
,
Float
>
;
using
Int32
=
TraitedType
<
int32_t
>
;
Float
()
using
Int64
=
TraitedType
<
int64_t
>
;
:
TraitedType
<
float
,
Float
>
(
"float"
)
using
UInt8
=
TraitedType
<
uint8_t
>
;
{
using
UInt32
=
TraitedType
<
uint32_t
>
;
}
using
UInt64
=
TraitedType
<
uint64_t
>
;
};
class
Int8
:
public
TraitedType
<
int8_t
,
Int8
>
{
friend
class
TraitedType
<
int8_t
,
Int8
>
;
Int8
()
:
TraitedType
<
int8_t
,
Int8
>
(
"int8_t"
)
{
}
};
class
Int32
:
public
TraitedType
<
int32_t
,
Int32
>
{
friend
class
TraitedType
<
int32_t
,
Int32
>
;
Int32
()
:
TraitedType
<
int32_t
,
Int32
>
(
"int32_t"
)
{
}
};
class
Int64
:
public
TraitedType
<
int64_t
,
Int64
>
{
friend
class
TraitedType
<
int64_t
,
Int64
>
;
Int64
()
:
TraitedType
<
int64_t
,
Int64
>
(
"int64_t"
)
{
}
};
class
UInt8
:
public
TraitedType
<
uint8_t
,
UInt8
>
{
friend
class
TraitedType
<
uint8_t
,
UInt8
>
;
UInt8
()
:
TraitedType
<
uint8_t
,
UInt8
>
(
"uint8_t"
)
{
}
};
class
UInt32
:
public
TraitedType
<
uint32_t
,
UInt32
>
{
friend
class
TraitedType
<
uint32_t
,
UInt32
>
;
UInt32
()
:
TraitedType
<
uint32_t
,
UInt32
>
(
"uint32_t"
)
{
}
};
class
UInt64
:
public
TraitedType
<
uint64_t
,
UInt64
>
{
friend
class
TraitedType
<
uint64_t
,
UInt64
>
;
UInt64
()
:
TraitedType
<
uint64_t
,
UInt64
>
(
"uint64_t"
)
{
}
};
}
}
}
}
src/ngraph/ops/constant.hpp
View file @
f106a582
...
@@ -19,10 +19,10 @@
...
@@ -19,10 +19,10 @@
namespace
ngraph
namespace
ngraph
{
{
// Defines methods to all constant scalars
// Defines methods to all constant scalars
class
ScalarConstantBase
Op
:
public
Node
class
ScalarConstantBase
:
public
Node
{
{
protected
:
protected
:
ScalarConstantBase
Op
(
const
std
::
shared_ptr
<
TensorViewType
>&
type
)
ScalarConstantBase
(
const
std
::
shared_ptr
<
TensorViewType
>&
type
)
:
Node
({},
type
)
:
Node
({},
type
)
{
{
}
}
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
// Implement a constant scalar for each element type.
// Implement a constant scalar for each element type.
// The static make method takes a
// The static make method takes a
template
<
typename
T
>
template
<
typename
T
>
class
ScalarConstant
Op
:
public
ScalarConstantBaseOp
class
ScalarConstant
:
public
ScalarConstantBase
{
{
public
:
public
:
// The ngraph element type
// The ngraph element type
...
@@ -41,8 +41,8 @@ namespace ngraph
...
@@ -41,8 +41,8 @@ namespace ngraph
// The C++ type that holds the element type
// The C++ type that holds the element type
using
ctype
=
typename
T
::
ctype
;
using
ctype
=
typename
T
::
ctype
;
ScalarConstant
Op
(
typename
T
::
ctype
value
)
ScalarConstant
(
typename
T
::
ctype
value
)
:
ScalarConstantBase
Op
(
std
::
make_shared
<
TensorViewType
>
(
T
::
type
,
Shape
{}))
:
ScalarConstantBase
(
std
::
make_shared
<
TensorViewType
>
(
T
::
element_type
()
,
Shape
{}))
,
m_value
(
value
)
,
m_value
(
value
)
{
{
}
}
...
@@ -54,20 +54,20 @@ namespace ngraph
...
@@ -54,20 +54,20 @@ namespace ngraph
// Make a constant from any value that can be converted to the C++ type we use
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
// to represent the values.
template
<
typename
U
>
template
<
typename
U
>
static
std
::
shared_ptr
<
ScalarConstant
Op
<
T
>>
make
(
U
value
)
static
std
::
shared_ptr
<
ScalarConstant
<
T
>>
make
(
U
value
)
{
{
return
std
::
make_shared
<
ScalarConstant
Op
<
T
>>
(
value
);
return
std
::
make_shared
<
ScalarConstant
<
T
>>
(
value
);
}
}
protected
:
protected
:
typename
T
::
ctype
m_value
;
typename
T
::
ctype
m_value
;
};
};
using
FloatScalarConstant
Op
=
ScalarConstantOp
<
element
::
Float
>
;
using
FloatScalarConstant
=
ScalarConstant
<
element
::
Float
>
;
using
Int8ScalarConstant
Op
=
ScalarConstantOp
<
element
::
Int8
>
;
using
Int8ScalarConstant
=
ScalarConstant
<
element
::
Int8
>
;
using
Int32ScalarConstant
Op
=
ScalarConstantOp
<
element
::
Int32
>
;
using
Int32ScalarConstant
=
ScalarConstant
<
element
::
Int32
>
;
using
Int64ScalarConstant
Op
=
ScalarConstantOp
<
element
::
Int64
>
;
using
Int64ScalarConstant
=
ScalarConstant
<
element
::
Int64
>
;
using
UInt8ScalarConstant
Op
=
ScalarConstantOp
<
element
::
UInt8
>
;
using
UInt8ScalarConstant
=
ScalarConstant
<
element
::
UInt8
>
;
using
UInt32ScalarConstant
Op
=
ScalarConstantOp
<
element
::
UInt32
>
;
using
UInt32ScalarConstant
=
ScalarConstant
<
element
::
UInt32
>
;
using
UInt64ScalarConstant
Op
=
ScalarConstantOp
<
element
::
UInt64
>
;
using
UInt64ScalarConstant
=
ScalarConstant
<
element
::
UInt64
>
;
}
}
src/ops/constant.cpp
View file @
f106a582
...
@@ -16,4 +16,4 @@
...
@@ -16,4 +16,4 @@
using
namespace
ngraph
;
using
namespace
ngraph
;
void
ScalarConstantBase
Op
::
propagate_types
()
{}
void
ScalarConstantBase
::
propagate_types
()
{}
src/types/element_type.cpp
View file @
f106a582
...
@@ -48,28 +48,3 @@ size_t ngraph::element::Type::size() const
...
@@ -48,28 +48,3 @@ size_t ngraph::element::Type::size() const
{
{
return
std
::
ceil
((
float
)
m_bitwidth
/
8.0
);
return
std
::
ceil
((
float
)
m_bitwidth
/
8.0
);
}
}
namespace
{
const
element
::
Float
s_float32_t
=
element
::
Float
{
"float"
};
const
element
::
Int8
s_int8_t
=
element
::
Int8
{
"int8_t"
};
const
element
::
Int32
s_int32_t
=
element
::
Int32
{
"int32_t"
};
const
element
::
Int64
s_int64_t
=
element
::
Int64
{
"int64_t"
};
const
element
::
UInt8
s_uint8_t
=
element
::
UInt8
{
"uint8_t"
};
const
element
::
UInt32
s_uint32_t
=
element
::
UInt32
{
"uint32_t"
};
const
element
::
UInt64
s_uint64_t
=
element
::
UInt64
{
"uint64_t"
};
}
template
<>
const
element
::
TraitedType
<
float
>&
element
::
TraitedType
<
float
>::
type
=
s_float32_t
;
template
<>
const
element
::
TraitedType
<
int8_t
>&
element
::
TraitedType
<
int8_t
>::
type
=
s_int8_t
;
template
<>
const
element
::
TraitedType
<
int32_t
>&
element
::
TraitedType
<
int32_t
>::
type
=
s_int32_t
;
template
<>
const
element
::
TraitedType
<
int64_t
>&
element
::
TraitedType
<
int64_t
>::
type
=
s_int64_t
;
template
<>
const
element
::
TraitedType
<
uint8_t
>&
element
::
TraitedType
<
uint8_t
>::
type
=
s_uint8_t
;
template
<>
const
element
::
TraitedType
<
uint32_t
>&
element
::
TraitedType
<
uint32_t
>::
type
=
s_uint32_t
;
template
<>
const
element
::
TraitedType
<
uint64_t
>&
element
::
TraitedType
<
uint64_t
>::
type
=
s_uint64_t
;
\ No newline at end of file
test/build_graph.cpp
View file @
f106a582
...
@@ -36,10 +36,10 @@ std::shared_ptr<Parameter> myfun<Parameter> (ngraph::element::Type&& element_typ
...
@@ -36,10 +36,10 @@ std::shared_ptr<Parameter> myfun<Parameter> (ngraph::element::Type&& element_typ
TEST
(
build_graph
,
build_simple
)
TEST
(
build_graph
,
build_simple
)
{
{
// Function with 4 parameters
// Function with 4 parameters
auto
arg0
=
myfun
<
Parameter
>
(
element
::
Float
::
type
,
Shape
{
7
,
3
});
auto
arg0
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
Shape
{
7
,
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
type
,
Shape
{
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
Shape
{
3
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
type
,
Shape
{
32
,
7
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
Shape
{
32
,
7
});
auto
arg3
=
op
::
parameter
(
element
::
Float
::
type
,
Shape
{
32
,
7
});
auto
arg3
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
Shape
{
32
,
7
});
auto
broadcast_1
=
op
::
broadcast
(
arg3
,
Shape
{
10
,
32
,
7
},
BroadcastOp
::
Axes
{
0
});
auto
broadcast_1
=
op
::
broadcast
(
arg3
,
Shape
{
10
,
32
,
7
},
BroadcastOp
::
Axes
{
0
});
auto
b1
=
myfun
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
BroadcastOp
::
Axes
{
0
});
auto
b1
=
myfun
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
BroadcastOp
::
Axes
{
0
});
auto
dot
=
op
::
dot
(
arg2
,
arg0
);
auto
dot
=
op
::
dot
(
arg2
,
arg0
);
...
@@ -56,7 +56,7 @@ TEST(build_graph, build_simple)
...
@@ -56,7 +56,7 @@ TEST(build_graph, build_simple)
TEST
(
build_graph
,
as_type
)
TEST
(
build_graph
,
as_type
)
{
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType
::
ptr
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
type
,
Shape
{
2
,
3
,
5
});
ValueType
::
ptr
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
()
,
Shape
{
2
,
3
,
5
});
auto
tv_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tv_vt
);
auto
tv_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tv_vt
);
ASSERT_EQ
(
tv_vt
,
tv_tv
);
ASSERT_EQ
(
tv_vt
,
tv_tv
);
auto
tv_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tv_vt
);
auto
tv_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tv_vt
);
...
@@ -73,14 +73,14 @@ TEST(build_graph, as_type)
...
@@ -73,14 +73,14 @@ TEST(build_graph, as_type)
// Check node comparisons
// Check node comparisons
TEST
(
build_graph
,
node_comparison
)
TEST
(
build_graph
,
node_comparison
)
{
{
auto
arg0
=
op
::
parameter
(
element
::
Float
::
type
,
{
32
,
3
});
auto
arg0
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
{
32
,
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
type
,
{
3
});
auto
arg1
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
{
3
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
type
,
{
32
});
auto
arg2
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
{
32
});
auto
dot
=
op
::
dot
(
arg0
,
arg1
);
auto
dot
=
op
::
dot
(
arg0
,
arg1
);
auto
add
=
op
::
add
(
dot
,
arg2
);
auto
add
=
op
::
add
(
dot
,
arg2
);
auto
parg
=
op
::
parameter
(
element
::
Float
::
type
,
{});
auto
parg
=
op
::
parameter
(
element
::
Float
::
element_type
()
,
{});
auto
pattern_dot
=
op
::
dot
(
parg
,
parg
);
auto
pattern_dot
=
op
::
dot
(
parg
,
parg
);
ASSERT_TRUE
(
pattern_dot
->
is_same_op_type
(
dot
));
ASSERT_TRUE
(
pattern_dot
->
is_same_op_type
(
dot
));
// TODO This passes because typeid is not behaving as documented.
// TODO This passes because typeid is not behaving as documented.
...
@@ -91,8 +91,8 @@ TEST(build_graph, node_comparison)
...
@@ -91,8 +91,8 @@ TEST(build_graph, node_comparison)
TEST
(
build_graph
,
literal
)
TEST
(
build_graph
,
literal
)
{
{
// float scalar from a float
// float scalar from a float
auto
float0
=
FloatScalarConstant
Op
::
make
(
3.0
);
auto
float0
=
FloatScalarConstant
::
make
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
type
,
Shape
{});
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
()
,
Shape
{});
ASSERT_EQ
(
float0
->
value
(),
3.0
);
ASSERT_EQ
(
float0
->
value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
type
(),
float_scalar_type
);
ASSERT_EQ
(
*
float0
->
type
(),
float_scalar_type
);
auto
d
=
op
::
dot
(
float0
,
float0
);
auto
d
=
op
::
dot
(
float0
,
float0
);
...
@@ -100,12 +100,12 @@ TEST(build_graph, literal)
...
@@ -100,12 +100,12 @@ TEST(build_graph, literal)
ASSERT_EQ
(
d
->
arguments
().
at
(
1
),
float0
);
ASSERT_EQ
(
d
->
arguments
().
at
(
1
),
float0
);
// float scalar from an int
// float scalar from an int
auto
float1
=
FloatScalarConstant
Op
::
make
(
3
);
auto
float1
=
FloatScalarConstant
::
make
(
3
);
ASSERT_EQ
(
float1
->
value
(),
3
);
ASSERT_EQ
(
float1
->
value
(),
3
);
ASSERT_EQ
(
*
float1
->
type
(),
float_scalar_type
);
ASSERT_EQ
(
*
float1
->
type
(),
float_scalar_type
);
auto
int32_0
=
Int32ScalarConstant
Op
::
make
(
3.0
);
auto
int32_0
=
Int32ScalarConstant
::
make
(
3.0
);
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
type
,
Shape
{});
auto
int32_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Int32
::
element_type
()
,
Shape
{});
ASSERT_EQ
(
int32_0
->
value
(),
3
);
ASSERT_EQ
(
int32_0
->
value
(),
3
);
ASSERT_EQ
(
*
int32_0
->
type
(),
int32_scalar_type
);
ASSERT_EQ
(
*
int32_0
->
type
(),
int32_scalar_type
);
ASSERT_NE
(
*
int32_0
->
type
(),
float_scalar_type
);
ASSERT_NE
(
*
int32_0
->
type
(),
float_scalar_type
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment