Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
fd881acc
Commit
fd881acc
authored
Sep 01, 2017
by
Scott Cyphers
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Review comments.
parent
064fb0fc
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
68 additions
and
79 deletions
+68
-79
element_type.hpp
src/ngraph/element_type.hpp
+10
-10
function.hpp
src/ngraph/function.hpp
+6
-3
node.cpp
src/ngraph/node.cpp
+3
-3
node.hpp
src/ngraph/node.hpp
+25
-2
op.hpp
src/ngraph/op.hpp
+6
-1
constant.hpp
src/ngraph/ops/constant.hpp
+1
-1
type.hpp
src/ngraph/type.hpp
+0
-42
parameter.cpp
src/ops/parameter.cpp
+1
-1
build_graph.cpp
test/build_graph.cpp
+13
-13
op.cpp
test/op.cpp
+2
-2
topological_sort.cpp
test/topological_sort.cpp
+1
-1
No files found.
src/ngraph/element_type.hpp
View file @
fd881acc
...
@@ -49,7 +49,7 @@ namespace ngraph
...
@@ -49,7 +49,7 @@ namespace ngraph
size_t
m_bitwidth
;
size_t
m_bitwidth
;
bool
m_is_float
;
bool
m_is_float
;
bool
m_is_signed
;
bool
m_is_signed
;
const
std
::
string
m_cname
;
const
std
::
string
&
m_cname
;
};
};
// Provides a compile-time name for a C++ type.
// Provides a compile-time name for a C++ type.
...
@@ -62,7 +62,7 @@ namespace ngraph
...
@@ -62,7 +62,7 @@ namespace ngraph
}
}
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_T
TN
(T) \
#define NGRAPH_DEFINE_T
RAITED_TYPE_NAME
(T) \
template <> \
template <> \
constexpr const char* traited_type_name<T>() \
constexpr const char* traited_type_name<T>() \
{ \
{ \
...
@@ -95,25 +95,25 @@ namespace ngraph
...
@@ -95,25 +95,25 @@ namespace ngraph
}
}
};
};
NGRAPH_DEFINE_T
TN
(
float
)
NGRAPH_DEFINE_T
RAITED_TYPE_NAME
(
float
)
using
Float
=
TraitedType
<
float
>
;
using
Float
32
=
TraitedType
<
float
>
;
NGRAPH_DEFINE_T
TN
(
int8_t
)
NGRAPH_DEFINE_T
RAITED_TYPE_NAME
(
int8_t
)
using
Int8
=
TraitedType
<
int8_t
>
;
using
Int8
=
TraitedType
<
int8_t
>
;
NGRAPH_DEFINE_T
TN
(
int32_t
)
NGRAPH_DEFINE_T
RAITED_TYPE_NAME
(
int32_t
)
using
Int32
=
TraitedType
<
int32_t
>
;
using
Int32
=
TraitedType
<
int32_t
>
;
NGRAPH_DEFINE_T
TN
(
int64_t
)
NGRAPH_DEFINE_T
RAITED_TYPE_NAME
(
int64_t
)
using
Int64
=
TraitedType
<
int64_t
>
;
using
Int64
=
TraitedType
<
int64_t
>
;
NGRAPH_DEFINE_T
TN
(
uint8_t
)
NGRAPH_DEFINE_T
RAITED_TYPE_NAME
(
uint8_t
)
using
UInt8
=
TraitedType
<
uint8_t
>
;
using
UInt8
=
TraitedType
<
uint8_t
>
;
NGRAPH_DEFINE_T
TN
(
uint32_t
)
NGRAPH_DEFINE_T
RAITED_TYPE_NAME
(
uint32_t
)
using
UInt32
=
TraitedType
<
uint32_t
>
;
using
UInt32
=
TraitedType
<
uint32_t
>
;
NGRAPH_DEFINE_T
TN
(
uint64_t
)
NGRAPH_DEFINE_T
RAITED_TYPE_NAME
(
uint64_t
)
using
UInt64
=
TraitedType
<
uint64_t
>
;
using
UInt64
=
TraitedType
<
uint64_t
>
;
}
}
}
}
src/ngraph/function.hpp
View file @
fd881acc
...
@@ -28,9 +28,12 @@ namespace ngraph
...
@@ -28,9 +28,12 @@ namespace ngraph
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
Function
(
const
std
::
shared_ptr
<
Node
>&
result
,
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>&
parameters
);
std
::
shared_ptr
<
Node
>
result
()
{
return
m_result
;
}
std
::
shared_ptr
<
Node
>
get_result
()
{
return
m_result
;
}
std
::
shared_ptr
<
Parameter
>
parameter
(
size_t
i
)
{
return
m_parameters
[
i
];
}
const
std
::
vector
<
std
::
shared_ptr
<
Parameter
>>
get_parameters
()
const
std
::
string
name
()
const
{
return
m_name
;
}
{
return
m_parameters
;
}
std
::
string
get_name
()
const
{
return
m_name
;
}
protected
:
protected
:
std
::
shared_ptr
<
Node
>
m_result
;
std
::
shared_ptr
<
Node
>
m_result
;
...
...
src/ngraph/node.cpp
View file @
fd881acc
...
@@ -18,9 +18,9 @@
...
@@ -18,9 +18,9 @@
size_t
ngraph
::
Node
::
m_next_instance_id
=
0
;
size_t
ngraph
::
Node
::
m_next_instance_id
=
0
;
ngraph
::
Node
::
Node
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
,
ngraph
::
Node
::
Node
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
,
std
::
shared_ptr
<
ValueType
>
type
)
std
::
shared_ptr
<
ValueType
>
value_
type
)
:
TypedValueMixin
(
type
)
:
m_arguments
(
arguments
)
,
m_
arguments
(
arguments
)
,
m_
value_type
(
value_type
)
,
m_instance_id
(
m_next_instance_id
++
)
,
m_instance_id
(
m_next_instance_id
++
)
{
{
// Add this node as a user of each argument.
// Add this node as a user of each argument.
...
...
src/ngraph/node.hpp
View file @
fd881acc
...
@@ -30,10 +30,19 @@ namespace ngraph
...
@@ -30,10 +30,19 @@ namespace ngraph
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor
/// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values.
/// view or a (possibly empty) tuple of values.
class
Node
:
public
TypedValueMixin
,
public
std
::
enable_shared_from_this
<
Node
>
class
Node
:
public
std
::
enable_shared_from_this
<
Node
>
{
{
protected
:
protected
:
Node
(
const
Nodes
&
arguments
,
std
::
shared_ptr
<
ValueType
>
type
=
nullptr
);
Node
(
const
Nodes
&
arguments
,
std
::
shared_ptr
<
ValueType
>
value_type
=
nullptr
);
Node
()
:
Node
({},
nullptr
)
{
}
Node
(
std
::
shared_ptr
<
ValueType
>
value_type
)
:
Node
({},
value_type
)
{
}
virtual
~
Node
()
{}
virtual
~
Node
()
{}
...
@@ -61,6 +70,19 @@ namespace ngraph
...
@@ -61,6 +70,19 @@ namespace ngraph
return
typeid
(
*
this
)
==
typeid
(
*
node
.
get
());
return
typeid
(
*
this
)
==
typeid
(
*
node
.
get
());
}
}
std
::
shared_ptr
<
ValueType
>
get_value_type
()
{
return
m_value_type
;
}
const
std
::
shared_ptr
<
ValueType
>
get_value_type
()
const
{
return
m_value_type
;
}
void
set_value_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
{
m_value_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
}
void
set_value_type
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
{
m_value_type
=
value_type
;
}
bool
is_op
()
const
;
bool
is_op
()
const
;
bool
is_parameter
()
const
;
bool
is_parameter
()
const
;
...
@@ -69,6 +91,7 @@ namespace ngraph
...
@@ -69,6 +91,7 @@ namespace ngraph
protected
:
protected
:
Nodes
m_arguments
;
Nodes
m_arguments
;
std
::
shared_ptr
<
ValueType
>
m_value_type
;
std
::
multiset
<
Node
*>
m_users
;
std
::
multiset
<
Node
*>
m_users
;
std
::
string
m_name
;
std
::
string
m_name
;
size_t
m_instance_id
;
size_t
m_instance_id
;
...
...
src/ngraph/op.hpp
View file @
fd881acc
...
@@ -80,7 +80,12 @@ namespace ngraph
...
@@ -80,7 +80,12 @@ namespace ngraph
{
{
public
:
public
:
Op
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
)
Op
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
arguments
)
:
Node
(
arguments
,
nullptr
)
:
Node
(
arguments
)
{
}
Op
()
:
Node
()
{
{
}
}
...
...
src/ngraph/ops/constant.hpp
View file @
fd881acc
...
@@ -63,7 +63,7 @@ namespace ngraph
...
@@ -63,7 +63,7 @@ namespace ngraph
typename
T
::
type
m_value
;
typename
T
::
type
m_value
;
};
};
using
Float
ScalarConstant
=
ScalarConstant
<
element
::
Float
>
;
using
Float
32ScalarConstant
=
ScalarConstant
<
element
::
Float32
>
;
using
Int8ScalarConstant
=
ScalarConstant
<
element
::
Int8
>
;
using
Int8ScalarConstant
=
ScalarConstant
<
element
::
Int8
>
;
using
Int32ScalarConstant
=
ScalarConstant
<
element
::
Int32
>
;
using
Int32ScalarConstant
=
ScalarConstant
<
element
::
Int32
>
;
using
Int64ScalarConstant
=
ScalarConstant
<
element
::
Int64
>
;
using
Int64ScalarConstant
=
ScalarConstant
<
element
::
Int64
>
;
...
...
src/ngraph/type.hpp
View file @
fd881acc
...
@@ -82,46 +82,4 @@ namespace ngraph
...
@@ -82,46 +82,4 @@ namespace ngraph
protected
:
protected
:
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
m_element_types
;
std
::
vector
<
std
::
shared_ptr
<
ValueType
>>
m_element_types
;
};
};
/**
** Mixin for objects with type information
**/
class
TypedValueMixin
{
public
:
TypedValueMixin
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
=
nullptr
)
:
m_value_type
(
value_type
)
{
}
/**
** Set the type
** /param type The new type
**/
void
set_value_type
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
{
m_value_type
=
value_type
;
}
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void
set_value_type
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
{
m_value_type
=
std
::
make_shared
<
TensorViewType
>
(
element_type
,
shape
);
}
/**
** The type associated with this value.
**/
std
::
shared_ptr
<
ValueType
>
get_value_type
()
{
return
m_value_type
;
}
/**
** The type associated with this value.
**/
const
std
::
shared_ptr
<
ValueType
>
get_value_type
()
const
{
return
m_value_type
;
}
protected
:
std
::
shared_ptr
<
ValueType
>
m_value_type
;
};
}
}
src/ops/parameter.cpp
View file @
fd881acc
...
@@ -20,7 +20,7 @@ using namespace std;
...
@@ -20,7 +20,7 @@ using namespace std;
using
namespace
ngraph
;
using
namespace
ngraph
;
Parameter
::
Parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
Parameter
::
Parameter
(
const
std
::
shared_ptr
<
ValueType
>&
value_type
)
:
Node
(
{},
value_type
)
:
Node
(
value_type
)
,
m_function
(
nullptr
)
,
m_function
(
nullptr
)
,
m_index
(
0
)
,
m_index
(
0
)
{
{
...
...
test/build_graph.cpp
View file @
fd881acc
...
@@ -23,10 +23,10 @@ using namespace ngraph;
...
@@ -23,10 +23,10 @@ using namespace ngraph;
TEST
(
build_graph
,
build_simple
)
TEST
(
build_graph
,
build_simple
)
{
{
// Function with 4 parameters
// Function with 4 parameters
auto
arg0
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
7
,
3
});
auto
arg0
=
node
<
Parameter
>
(
element
::
Float
32
::
element_type
(),
Shape
{
7
,
3
});
auto
arg1
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
3
});
auto
arg1
=
node
<
Parameter
>
(
element
::
Float
32
::
element_type
(),
Shape
{
3
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
,
7
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float
32
::
element_type
(),
Shape
{
32
,
7
});
auto
arg3
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
,
7
});
auto
arg3
=
node
<
Parameter
>
(
element
::
Float
32
::
element_type
(),
Shape
{
32
,
7
});
auto
broadcast_1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
broadcast_1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
b1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
b1
=
node
<
BroadcastOp
>
(
arg3
,
Shape
{
10
,
32
,
7
},
AxisSet
{
0
});
auto
dot
=
node
<
DotOp
>
(
arg2
,
arg0
);
auto
dot
=
node
<
DotOp
>
(
arg2
,
arg0
);
...
@@ -35,14 +35,14 @@ TEST(build_graph, build_simple)
...
@@ -35,14 +35,14 @@ TEST(build_graph, build_simple)
auto
cluster_0
=
op
::
function
(
dot
,
{
arg0
,
arg1
,
arg2
,
arg3
});
auto
cluster_0
=
op
::
function
(
dot
,
{
arg0
,
arg1
,
arg2
,
arg3
});
ASSERT_EQ
(
cluster_0
->
result
(),
dot
);
ASSERT_EQ
(
cluster_0
->
get_
result
(),
dot
);
}
}
// Check upcasting from ValueType.
// Check upcasting from ValueType.
TEST
(
build_graph
,
as_type
)
TEST
(
build_graph
,
as_type
)
{
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
auto
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
(),
Shape
{
2
,
3
,
5
});
auto
tv_vt
=
make_shared
<
TensorViewType
>
(
element
::
Float
32
::
element_type
(),
Shape
{
2
,
3
,
5
});
auto
tv_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tv_vt
);
auto
tv_tv
=
dynamic_pointer_cast
<
TensorViewType
>
(
tv_vt
);
ASSERT_EQ
(
tv_vt
,
tv_tv
);
ASSERT_EQ
(
tv_vt
,
tv_tv
);
auto
tv_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tv_vt
);
auto
tv_tp
=
dynamic_pointer_cast
<
TupleType
>
(
tv_vt
);
...
@@ -59,14 +59,14 @@ TEST(build_graph, as_type)
...
@@ -59,14 +59,14 @@ TEST(build_graph, as_type)
// Check node comparisons
// Check node comparisons
TEST
(
build_graph
,
node_comparison
)
TEST
(
build_graph
,
node_comparison
)
{
{
auto
arg0
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
,
3
});
auto
arg0
=
node
<
Parameter
>
(
element
::
Float
32
::
element_type
(),
Shape
{
32
,
3
});
auto
arg1
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
3
});
auto
arg1
=
node
<
Parameter
>
(
element
::
Float
32
::
element_type
(),
Shape
{
3
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{
32
});
auto
arg2
=
node
<
Parameter
>
(
element
::
Float
32
::
element_type
(),
Shape
{
32
});
auto
dot
=
op
::
dot
(
arg0
,
arg1
);
auto
dot
=
op
::
dot
(
arg0
,
arg1
);
auto
add
=
op
::
add
(
dot
,
arg2
);
auto
add
=
op
::
add
(
dot
,
arg2
);
auto
parg
=
node
<
Parameter
>
(
element
::
Float
::
element_type
(),
Shape
{});
auto
parg
=
node
<
Parameter
>
(
element
::
Float
32
::
element_type
(),
Shape
{});
auto
pattern_dot
=
node
<
DotOp
>
(
parg
,
parg
);
auto
pattern_dot
=
node
<
DotOp
>
(
parg
,
parg
);
ASSERT_TRUE
(
pattern_dot
->
is_same_op_type
(
dot
));
ASSERT_TRUE
(
pattern_dot
->
is_same_op_type
(
dot
));
// TODO This passes because typeid is not behaving as documented.
// TODO This passes because typeid is not behaving as documented.
...
@@ -78,8 +78,8 @@ TEST(build_graph, literal)
...
@@ -78,8 +78,8 @@ TEST(build_graph, literal)
{
{
// float scalar from a float
// float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0);
//auto float0 = FloatScalarConstant::make(3.0);
auto
float0
=
node
<
FloatScalarConstant
>
(
3.0
);
auto
float0
=
node
<
Float
32
ScalarConstant
>
(
3.0
);
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
::
element_type
(),
Shape
{});
auto
float_scalar_type
=
make_shared
<
TensorViewType
>
(
element
::
Float
32
::
element_type
(),
Shape
{});
ASSERT_EQ
(
float0
->
get_value
(),
3.0
);
ASSERT_EQ
(
float0
->
get_value
(),
3.0
);
ASSERT_EQ
(
*
float0
->
get_value_type
(),
float_scalar_type
);
ASSERT_EQ
(
*
float0
->
get_value_type
(),
float_scalar_type
);
auto
d
=
node
<
DotOp
>
(
float0
,
float0
);
auto
d
=
node
<
DotOp
>
(
float0
,
float0
);
...
@@ -87,7 +87,7 @@ TEST(build_graph, literal)
...
@@ -87,7 +87,7 @@ TEST(build_graph, literal)
ASSERT_EQ
(
d
->
get_arguments
().
at
(
1
),
float0
);
ASSERT_EQ
(
d
->
get_arguments
().
at
(
1
),
float0
);
// float scalar from an int
// float scalar from an int
auto
float1
=
node
<
FloatScalarConstant
>
(
3
);
auto
float1
=
node
<
Float
32
ScalarConstant
>
(
3
);
ASSERT_EQ
(
float1
->
get_value
(),
3
);
ASSERT_EQ
(
float1
->
get_value
(),
3
);
ASSERT_EQ
(
*
float1
->
get_value_type
(),
float_scalar_type
);
ASSERT_EQ
(
*
float1
->
get_value_type
(),
float_scalar_type
);
...
...
test/op.cpp
View file @
fd881acc
...
@@ -23,7 +23,7 @@ using namespace ngraph;
...
@@ -23,7 +23,7 @@ using namespace ngraph;
TEST
(
op
,
is_op
)
TEST
(
op
,
is_op
)
{
{
auto
arg0
=
op
::
parameter
(
element
::
Float
::
element_type
(),
{
1
});
auto
arg0
=
op
::
parameter
(
element
::
Float
32
::
element_type
(),
{
1
});
ASSERT_NE
(
nullptr
,
arg0
);
ASSERT_NE
(
nullptr
,
arg0
);
EXPECT_TRUE
(
arg0
->
is_parameter
());
EXPECT_TRUE
(
arg0
->
is_parameter
());
EXPECT_FALSE
(
arg0
->
is_op
());
EXPECT_FALSE
(
arg0
->
is_op
());
...
@@ -31,7 +31,7 @@ TEST(op, is_op)
...
@@ -31,7 +31,7 @@ TEST(op, is_op)
TEST
(
op
,
is_parameter
)
TEST
(
op
,
is_parameter
)
{
{
auto
arg0
=
op
::
parameter
(
element
::
Float
::
element_type
(),
{
1
});
auto
arg0
=
op
::
parameter
(
element
::
Float
32
::
element_type
(),
{
1
});
ASSERT_NE
(
nullptr
,
arg0
);
ASSERT_NE
(
nullptr
,
arg0
);
auto
t0
=
op
::
add
(
arg0
,
arg0
);
auto
t0
=
op
::
add
(
arg0
,
arg0
);
ASSERT_NE
(
nullptr
,
t0
);
ASSERT_NE
(
nullptr
,
t0
);
...
...
test/topological_sort.cpp
View file @
fd881acc
...
@@ -61,7 +61,7 @@ TEST(topological_sort, basic)
...
@@ -61,7 +61,7 @@ TEST(topological_sort, basic)
vector
<
shared_ptr
<
Parameter
>>
args
;
vector
<
shared_ptr
<
Parameter
>>
args
;
for
(
int
i
=
0
;
i
<
10
;
i
++
)
for
(
int
i
=
0
;
i
<
10
;
i
++
)
{
{
auto
arg
=
op
::
parameter
(
element
::
Float
::
element_type
(),
{
1
});
auto
arg
=
op
::
parameter
(
element
::
Float
32
::
element_type
(),
{
1
});
ASSERT_NE
(
nullptr
,
arg
);
ASSERT_NE
(
nullptr
,
arg
);
args
.
push_back
(
arg
);
args
.
push_back
(
arg
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment