Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
33b51160
Commit
33b51160
authored
Nov 30, 2017
by
Adam Procter
Browse files
Options
Browse Files
Download
Plain Diff
Merge remote-tracking branch 'origin' into aprocter/de-eigenize
parents
bed290db
e6cc7d8b
Show whitespace changes
Inline
Side-by-side
Showing
24 changed files
with
957 additions
and
55 deletions
+957
-55
CMakeLists.txt
src/ngraph/CMakeLists.txt
+15
-1
function.hpp
src/ngraph/function.hpp
+1
-0
json.hpp
src/ngraph/json.hpp
+2
-0
node.cpp
src/ngraph/node.cpp
+5
-0
node.hpp
src/ngraph/node.hpp
+2
-0
function_call.hpp
src/ngraph/ops/function_call.hpp
+2
-2
get_tuple_element.hpp
src/ngraph/ops/get_tuple_element.hpp
+1
-3
reduce.hpp
src/ngraph/ops/reduce.hpp
+1
-4
tuple.hpp
src/ngraph/ops/tuple.hpp
+1
-1
manager.cpp
src/ngraph/pass/manager.cpp
+2
-21
emitter.cpp
src/ngraph/runtime/cpu/emitter.cpp
+1
-1
external_function.cpp
src/ngraph/runtime/cpu/external_function.cpp
+3
-3
external_function.cpp
src/ngraph/runtime/ngvm/external_function.cpp
+1
-1
serializer.cpp
src/ngraph/serializer.cpp
+655
-0
serializer.hpp
src/ngraph/serializer.hpp
+28
-0
element_type.cpp
src/ngraph/types/element_type.cpp
+24
-3
element_type.hpp
src/ngraph/types/element_type.hpp
+8
-11
util.cpp
src/ngraph/util.cpp
+28
-1
util.hpp
src/ngraph/util.hpp
+3
-1
backend_test.in.cpp
test/backend_test.in.cpp
+2
-1
copy.cpp
test/copy.cpp
+1
-1
element_type.cpp
test/element_type.cpp
+49
-0
serialize.cpp
test/serialize.cpp
+84
-0
util.cpp
test/util.cpp
+38
-0
No files found.
src/ngraph/CMakeLists.txt
View file @
33b51160
...
...
@@ -89,15 +89,29 @@ set (SRC
runtime/tensor_view.cpp
runtime/tuple.cpp
runtime/utils.cpp
serializer.cpp
shape.cpp
types/element_type.cpp
types/type.cpp
util.cpp
)
message
(
STATUS
${
CMAKE_CURRENT_SOURCE_DIR
}
/ops
)
file
(
GLOB_RECURSE OPS
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/ops"
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/ops/*.hpp"
)
foreach
(
OP
${
OPS
}
)
file
(
STRINGS
${
OP
}
OP_CLASS REGEX
"class [A-Za-z0-9_]+ :"
)
foreach
(
LINE
${
OP_CLASS
}
)
string
(
REGEX REPLACE
".*class ([A-Za-z0-9_]+) : public ([A-Za-z0-9_]+).*"
"
\\
1:
\\
2"
CLASS_FOUND
${
LINE
}
)
set
(
OP_CLASS_LIST
${
OP_CLASS_LIST
}
${
CLASS_FOUND
}
)
endforeach
(
LINE
${
OP_CLASS
}
)
endforeach
()
message
(
STATUS
"
${
CMAKE_CURRENT_BINARY_DIR
}
/ops_list.txt"
)
string
(
REPLACE
";"
"
\n
"
OP_CLASS_LINES
"
${
OP_CLASS_LIST
}
"
)
file
(
WRITE
"
${
CMAKE_CURRENT_BINARY_DIR
}
/ops_list.txt"
"
${
OP_CLASS_LINES
}
"
)
# find_program (GRAPHVIZ dot)
# message (STATUS "graphviz '${GRAPHVIZ}'")
find_package
(
Graphviz
)
find_package
(
Graphviz
QUIET
)
if
(
GRAPHVIZ_FOUND
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-DGRAPHVIZ_FOUND"
)
endif
()
...
...
src/ngraph/function.hpp
View file @
33b51160
...
...
@@ -40,6 +40,7 @@ namespace ngraph
const
std
::
string
&
name
=
""
);
std
::
shared_ptr
<
Node
>
get_result
()
{
return
m_result
;
}
std
::
shared_ptr
<
const
Node
>
get_result
()
const
{
return
m_result
;
}
const
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>&
get_parameters
()
const
{
return
m_parameters
;
...
...
src/ngraph/json.hpp
View file @
33b51160
// clang-format off
#pragma clang diagnostic ignored "-Weverything"
/*
__ _____ _____ _____
__| | __| | | | JSON for Modern C++
...
...
src/ngraph/node.cpp
View file @
33b51160
...
...
@@ -181,6 +181,11 @@ std::shared_ptr<Node> Node::backprop_node(const std::shared_ptr<Node>& x,
return
adjoints_it
->
second
.
get
(
x
);
}
std
::
shared_ptr
<
Function
>
Node
::
get_function
()
const
{
return
nullptr
;
}
namespace
ngraph
{
ostream
&
operator
<<
(
ostream
&
out
,
const
Node
&
node
)
...
...
src/ngraph/node.hpp
View file @
33b51160
...
...
@@ -111,6 +111,8 @@ namespace ngraph
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
std
::
vector
<
std
::
shared_ptr
<
Node
>>&
new_args
)
const
=
0
;
virtual
std
::
shared_ptr
<
Function
>
get_function
()
const
;
protected
:
std
::
string
m_node_type
;
Nodes
m_arguments
;
...
...
src/ngraph/ops/function_call.hpp
View file @
33b51160
...
...
@@ -45,7 +45,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class
FunctionCall
:
public
ngraph
::
Node
class
FunctionCall
:
public
Node
{
public
:
/// \brief Constructs a function call operation.
...
...
@@ -62,7 +62,7 @@ namespace ngraph
}
/// \return The function to be called.
std
::
shared_ptr
<
Function
>
get_function
()
const
{
return
m_function
;
}
std
::
shared_ptr
<
Function
>
get_function
()
const
override
{
return
m_function
;
}
protected
:
std
::
shared_ptr
<
Function
>
m_function
;
};
...
...
src/ngraph/ops/get_tuple_element.hpp
View file @
33b51160
...
...
@@ -20,8 +20,6 @@ namespace ngraph
{
namespace
op
{
class
Node
;
/// \brief Operation to get an element from a tuple.
///
/// ## Parameters
...
...
@@ -47,7 +45,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class
GetTupleElement
:
public
ngraph
::
Node
class
GetTupleElement
:
public
Node
{
public
:
/// \brief Constructs a get-tuple-element operation.
...
...
src/ngraph/ops/reduce.hpp
View file @
33b51160
...
...
@@ -111,10 +111,7 @@ namespace ngraph
}
/// \return The function to use for reduction.
std
::
shared_ptr
<
Function
>
get_reduction_function
()
const
{
return
m_reduction_function
;
}
std
::
shared_ptr
<
Function
>
get_function
()
const
override
{
return
m_reduction_function
;
}
/// \return The axis positions (0-based) to be eliminated through reduction.
const
AxisSet
&
get_reduction_axes
()
const
{
return
m_reduction_axes
;
}
protected
:
...
...
src/ngraph/ops/tuple.hpp
View file @
33b51160
...
...
@@ -39,7 +39,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class
Tuple
:
public
ngraph
::
Node
class
Tuple
:
public
Node
{
public
:
/// \brief Constructs a tuple construction operation.
...
...
src/ngraph/pass/manager.cpp
View file @
33b51160
...
...
@@ -16,12 +16,12 @@
#include <memory>
#include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
...
...
@@ -38,30 +38,11 @@ void ngraph::pass::Manager::initialize_default_passes()
{
}
static
void
find_functions
(
shared_ptr
<
Function
>
f
,
set
<
shared_ptr
<
Function
>>&
funcs
)
{
funcs
.
insert
(
f
);
for
(
shared_ptr
<
Node
>
node
:
f
->
get_ops
())
{
shared_ptr
<
op
::
FunctionCall
>
fc
=
dynamic_pointer_cast
<
op
::
FunctionCall
>
(
node
);
if
(
fc
)
{
find_functions
(
fc
->
get_function
(),
funcs
);
}
shared_ptr
<
op
::
Reduce
>
reduce
=
dynamic_pointer_cast
<
op
::
Reduce
>
(
node
);
if
(
reduce
)
{
find_functions
(
reduce
->
get_reduction_function
(),
funcs
);
}
}
}
void
ngraph
::
pass
::
Manager
::
run_passes
(
shared_ptr
<
Function
>
func
)
{
// find all functions
set
<
shared_ptr
<
Function
>>
tfs
;
find_functions
(
func
,
tfs
);
traverse_functions
(
func
,
[
&
](
shared_ptr
<
Function
>
f
)
{
tfs
.
insert
(
f
);
}
);
get_state
().
set_functions
(
tfs
);
vector
<
shared_ptr
<
Function
>>
fs
;
...
...
src/ngraph/runtime/cpu/emitter.cpp
View file @
33b51160
...
...
@@ -1001,7 +1001,7 @@ void Emitter::EmitReduce(const ngraph::Node* n,
const
std
::
vector
<
TensorViewInfo
>&
outputs
)
{
auto
reduce
=
static_cast
<
const
op
::
Reduce
*>
(
n
);
auto
reduction_function
=
reduce
->
get_
reduction_
function
();
auto
reduction_function
=
reduce
->
get_function
();
auto
reductee_type
=
reduce
->
get_arguments
().
at
(
0
)
->
get_value_type
();
auto
reductee_tensor_view_type
=
dynamic_pointer_cast
<
const
TensorViewType
>
(
reductee_type
);
...
...
src/ngraph/runtime/cpu/external_function.cpp
View file @
33b51160
...
...
@@ -248,9 +248,9 @@ using namespace ngraph::runtime::cpu::eigen;
{
for
(
descriptor
::
Tensor
*
tensor
:
node
->
liveness_new_list
)
{
TU
<<
tensor
->
get_element_type
()
<<
"* "
<<
tensor
->
get_name
()
<<
" = ("
<<
tensor
->
get_element_type
()
<<
"*)(memory_handler.get_ptr("
<<
tensor
->
get_pool_offset
()
<<
"));
\n
"
;
TU
<<
tensor
->
get_element_type
()
.
c_type_string
()
<<
"* "
<<
tensor
->
get_name
()
<<
" = ("
<<
tensor
->
get_element_type
().
c_type_string
()
<<
"*)(memory_handler.get_ptr("
<<
tensor
->
get_pool_offset
()
<<
"));
\n
"
;
}
}
TU
<<
"
\n
"
;
...
...
src/ngraph/runtime/ngvm/external_function.cpp
View file @
33b51160
...
...
@@ -662,7 +662,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
REGISTER_TO_OP_MAP
(
op
::
Reduce
)
{
auto
reduce
=
static_cast
<
const
op
::
Reduce
*>
(
n
);
auto
reduction_function
=
reduce
->
get_
reduction_
function
();
auto
reduction_function
=
reduce
->
get_function
();
std
::
shared_ptr
<
ExternalFunction
>
external
;
...
...
src/ngraph/serializer.cpp
0 → 100644
View file @
33b51160
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/serializer.hpp"
#include "ngraph/ops/abs.hpp"
#include "ngraph/ops/acos.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/asin.hpp"
#include "ngraph/ops/atan.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/cos.hpp"
#include "ngraph/ops/cosh.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/equal.hpp"
#include "ngraph/ops/exp.hpp"
#include "ngraph/ops/floor.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/greater.hpp"
#include "ngraph/ops/greater_eq.hpp"
#include "ngraph/ops/less.hpp"
#include "ngraph/ops/less_eq.hpp"
#include "ngraph/ops/log.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/minimum.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/sign.hpp"
#include "ngraph/ops/sin.hpp"
#include "ngraph/ops/sinh.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/ops/tan.hpp"
#include "ngraph/ops/tanh.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/util.hpp"
using
namespace
ngraph
;
using
namespace
std
;
using
json
=
nlohmann
::
json
;
std
::
shared_ptr
<
ngraph
::
Function
>
read_function
(
const
json
&
,
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Function
>>&
);
json
write
(
const
ngraph
::
Function
&
);
json
write
(
const
ngraph
::
Node
&
);
json
write
(
const
ngraph
::
element
::
Type
&
);
// This stupidity is caused by the fact that we do not pass element types
// by value but by reference even though they can be compared. There is no reason to pass
// them by reference EVERYWERE but here we are...
const
element
::
Type
&
to_ref
(
const
element
::
Type
&
t
)
{
if
(
t
==
element
::
boolean
)
{
return
element
::
boolean
;
}
if
(
t
==
element
::
f32
)
{
return
element
::
f32
;
}
if
(
t
==
element
::
f64
)
{
return
element
::
f64
;
}
if
(
t
==
element
::
i8
)
{
return
element
::
i8
;
}
if
(
t
==
element
::
i16
)
{
return
element
::
i16
;
}
if
(
t
==
element
::
i32
)
{
return
element
::
i32
;
}
if
(
t
==
element
::
i64
)
{
return
element
::
i64
;
}
if
(
t
==
element
::
u8
)
{
return
element
::
u8
;
}
if
(
t
==
element
::
u16
)
{
return
element
::
u16
;
}
if
(
t
==
element
::
u32
)
{
return
element
::
u32
;
}
if
(
t
==
element
::
u64
)
{
return
element
::
u64
;
}
throw
runtime_error
(
"type not valid"
);
}
static
json
write_element_type
(
const
ngraph
::
element
::
Type
&
n
)
{
json
j
;
j
[
"bitwidth"
]
=
n
.
bitwidth
();
j
[
"is_real"
]
=
n
.
is_real
();
j
[
"is_signed"
]
=
n
.
is_signed
();
j
[
"c_type_string"
]
=
n
.
c_type_string
();
return
j
;
}
static
const
element
::
Type
&
read_element_type
(
const
json
&
j
)
{
size_t
bitwidth
=
j
.
at
(
"bitwidth"
).
get
<
size_t
>
();
bool
is_real
=
j
.
at
(
"is_real"
).
get
<
bool
>
();
bool
is_signed
=
j
.
at
(
"is_signed"
).
get
<
bool
>
();
string
c_type_string
=
j
.
at
(
"c_type_string"
).
get
<
string
>
();
return
to_ref
(
element
::
Type
(
bitwidth
,
is_real
,
is_signed
,
c_type_string
));
}
string
ngraph
::
serialize
(
shared_ptr
<
ngraph
::
Function
>
func
)
{
json
j
;
vector
<
json
>
functions
;
traverse_functions
(
func
,
[
&
](
shared_ptr
<
ngraph
::
Function
>
f
)
{
functions
.
push_back
(
write
(
*
f
));
});
for
(
auto
it
=
functions
.
rbegin
();
it
!=
functions
.
rend
();
it
++
)
{
j
.
push_back
(
*
it
);
}
return
j
.
dump
();
}
shared_ptr
<
ngraph
::
Function
>
ngraph
::
deserialize
(
istream
&
in
)
{
json
js
=
json
::
array
();
shared_ptr
<
Function
>
rc
;
in
>>
js
;
unordered_map
<
string
,
shared_ptr
<
Function
>>
function_map
;
for
(
json
func
:
js
)
{
shared_ptr
<
Function
>
f
=
read_function
(
func
,
function_map
);
if
(
rc
==
nullptr
)
{
rc
=
f
;
}
}
return
rc
;
}
json
write
(
const
Function
&
f
)
{
json
function
;
function
[
"name"
]
=
f
.
get_name
();
function
[
"result_type"
]
=
write_element_type
(
f
.
get_result_type
()
->
get_element_type
());
function
[
"result_shape"
]
=
f
.
get_result_type
()
->
get_shape
();
for
(
auto
param
:
f
.
get_parameters
())
{
function
[
"parameters"
].
push_back
(
param
->
get_name
());
}
function
[
"result"
].
push_back
(
f
.
get_result
()
->
get_name
());
list
<
shared_ptr
<
Node
>>
result_list
;
{
deque
<
Node
*>
independent_nodes
;
unordered_map
<
const
Node
*
,
size_t
>
node_depencency_count
;
unordered_map
<
Node
*
,
shared_ptr
<
Node
>>
node_map
;
traverse_nodes
(
const_cast
<
Function
*>
(
&
f
),
[
&
](
shared_ptr
<
Node
>
node
)
{
node_map
[
node
.
get
()]
=
node
;
node_depencency_count
[
node
.
get
()]
=
node
->
get_arguments
().
size
();
if
(
node
->
get_arguments
().
size
()
==
0
)
{
independent_nodes
.
push_back
(
node
.
get
());
}
});
while
(
independent_nodes
.
size
()
>
0
)
{
auto
independent_node
=
independent_nodes
.
front
();
result_list
.
push_back
(
node_map
[
independent_node
]);
independent_nodes
.
pop_front
();
for
(
auto
user
:
independent_node
->
users
())
{
node_depencency_count
[
user
]
-=
1
;
size_t
count
=
node_depencency_count
[
user
];
if
(
count
==
0
)
{
independent_nodes
.
push_back
(
user
);
}
}
}
}
json
nodes
;
for
(
shared_ptr
<
Node
>
node
:
result_list
)
{
nodes
.
push_back
(
write
(
*
node
));
}
function
[
"ops"
]
=
nodes
;
return
function
;
}
shared_ptr
<
ngraph
::
Function
>
read_function
(
const
json
&
func_js
,
unordered_map
<
string
,
shared_ptr
<
Function
>>&
function_map
)
{
shared_ptr
<
ngraph
::
Function
>
rc
;
string
func_name
=
func_js
.
at
(
"name"
).
get
<
string
>
();
vector
<
string
>
func_result
=
func_js
.
at
(
"result"
).
get
<
vector
<
string
>>
();
vector
<
string
>
func_parameters
=
func_js
.
at
(
"parameters"
).
get
<
vector
<
string
>>
();
const
element
::
Type
&
result_type
=
read_element_type
(
func_js
.
at
(
"result_type"
));
vector
<
size_t
>
result_shape
=
func_js
.
at
(
"result_shape"
).
get
<
vector
<
size_t
>>
();
unordered_map
<
string
,
shared_ptr
<
Node
>>
node_map
;
for
(
json
node_js
:
func_js
.
at
(
"ops"
))
{
string
node_name
=
node_js
.
at
(
"name"
).
get
<
string
>
();
string
node_op
=
node_js
.
at
(
"op"
).
get
<
string
>
();
const
element
::
Type
&
node_etype
=
read_element_type
(
node_js
.
at
(
"element_type"
));
vector
<
string
>
node_inputs
=
node_js
.
at
(
"inputs"
).
get
<
vector
<
string
>>
();
vector
<
string
>
node_outputs
=
node_js
.
at
(
"outputs"
).
get
<
vector
<
string
>>
();
shared_ptr
<
Node
>
node
;
shared_ptr
<
Function
>
function_ptr
=
nullptr
;
vector
<
shared_ptr
<
Node
>>
args
;
for
(
const
string
&
name
:
node_inputs
)
{
args
.
push_back
(
node_map
.
at
(
name
));
}
vector
<
string
>
known_nodes
;
for
(
auto
x
:
node_map
)
{
known_nodes
.
push_back
(
x
.
first
);
}
if
(
node_op
==
"Abs"
)
{
node
=
make_shared
<
op
::
Abs
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Acos"
)
{
node
=
make_shared
<
op
::
Acos
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Add"
)
{
node
=
make_shared
<
op
::
Add
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Asin"
)
{
node
=
make_shared
<
op
::
Asin
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Atan"
)
{
node
=
make_shared
<
op
::
Atan
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Broadcast"
)
{
auto
shape
=
node_js
.
at
(
"shape"
).
get
<
vector
<
size_t
>>
();
auto
axes
=
node_js
.
at
(
"axes"
).
get
<
set
<
size_t
>>
();
node
=
make_shared
<
op
::
Broadcast
>
(
args
[
0
],
shape
,
axes
);
}
else
if
(
node_op
==
"Ceiling"
)
{
node
=
make_shared
<
op
::
Ceiling
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Concat"
)
{
auto
axis
=
node_js
.
at
(
"axis"
).
get
<
size_t
>
();
node
=
make_shared
<
op
::
Concat
>
(
args
,
axis
);
}
else
if
(
node_op
==
"Constant"
)
{
auto
shape
=
node_js
.
at
(
"shape"
).
get
<
vector
<
size_t
>>
();
auto
value
=
node_js
.
at
(
"value"
).
get
<
vector
<
string
>>
();
node
=
make_shared
<
op
::
Constant
>
(
node_etype
,
shape
,
value
);
}
else
if
(
node_op
==
"Convert"
)
{
auto
target_type
=
read_element_type
(
node_js
.
at
(
"target_type"
));
node
=
make_shared
<
op
::
Convert
>
(
args
[
0
],
target_type
);
}
else
if
(
node_op
==
"Cos"
)
{
node
=
make_shared
<
op
::
Cos
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Cosh"
)
{
node
=
make_shared
<
op
::
Cosh
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Divide"
)
{
node
=
make_shared
<
op
::
Divide
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Dot"
)
{
node
=
make_shared
<
op
::
Dot
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Equal"
)
{
node
=
make_shared
<
op
::
Equal
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Exp"
)
{
node
=
make_shared
<
op
::
Exp
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Floor"
)
{
node
=
make_shared
<
op
::
Floor
>
(
args
[
0
]);
}
else
if
(
node_op
==
"FunctionCall"
)
{
string
function_name
=
node_js
.
at
(
"function"
).
get
<
string
>
();
shared_ptr
<
Function
>
f_ptr
=
function_map
.
at
(
function_name
);
node
=
make_shared
<
op
::
FunctionCall
>
(
f_ptr
,
args
);
}
// else if (node_op == "GetTupleElement")
// {
// node = make_shared<op::GetTupleElement>(args[0]);
// }
else
if
(
node_op
==
"Greater"
)
{
node
=
make_shared
<
op
::
Greater
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"GreaterEq"
)
{
node
=
make_shared
<
op
::
GreaterEq
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Less"
)
{
node
=
make_shared
<
op
::
Less
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"LessEq"
)
{
node
=
make_shared
<
op
::
LessEq
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Log"
)
{
node
=
make_shared
<
op
::
Log
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Maximum"
)
{
node
=
make_shared
<
op
::
Maximum
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Minimum"
)
{
node
=
make_shared
<
op
::
Minimum
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Multiply"
)
{
node
=
make_shared
<
op
::
Multiply
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Negative"
)
{
node
=
make_shared
<
op
::
Negative
>
(
args
[
0
]);
}
else
if
(
node_op
==
"NotEqual"
)
{
node
=
make_shared
<
op
::
NotEqual
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Parameter"
)
{
auto
shape
=
node_js
.
at
(
"shape"
);
node
=
make_shared
<
op
::
Parameter
>
(
node_etype
,
shape
);
}
else
if
(
node_op
==
"Power"
)
{
node
=
make_shared
<
op
::
Power
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Reduce"
)
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
();
node
=
make_shared
<
op
::
Reduce
>
(
args
[
0
],
args
[
1
],
function_ptr
,
reduction_axes
);
}
else
if
(
node_op
==
"Remainder"
)
{
node
=
make_shared
<
op
::
Remainder
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Reshape"
)
{
auto
input_order
=
node_js
.
at
(
"input_order"
).
get
<
vector
<
size_t
>>
();
auto
output_shape
=
node_js
.
at
(
"output_shape"
).
get
<
vector
<
size_t
>>
();
node
=
make_shared
<
op
::
Reshape
>
(
args
[
0
],
input_order
,
output_shape
);
}
else
if
(
node_op
==
"Select"
)
{
node
=
make_shared
<
op
::
Select
>
(
args
[
0
],
args
[
1
],
args
[
2
]);
}
else
if
(
node_op
==
"Sign"
)
{
node
=
make_shared
<
op
::
Sign
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Sin"
)
{
node
=
make_shared
<
op
::
Sin
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Sinh"
)
{
node
=
make_shared
<
op
::
Sinh
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Slice"
)
{
auto
lower_bounds
=
node_js
.
at
(
"lower_bounds"
).
get
<
vector
<
size_t
>>
();
auto
upper_bounds
=
node_js
.
at
(
"upper_bounds"
).
get
<
vector
<
size_t
>>
();
auto
step
=
node_js
.
at
(
"step"
).
get
<
vector
<
size_t
>>
();
node
=
make_shared
<
op
::
Slice
>
(
args
[
0
],
lower_bounds
,
upper_bounds
,
step
);
}
else
if
(
node_op
==
"Subtract"
)
{
node
=
make_shared
<
op
::
Subtract
>
(
args
[
0
],
args
[
1
]);
}
else
if
(
node_op
==
"Sum"
)
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
();
node
=
make_shared
<
op
::
Sum
>
(
args
[
0
],
reduction_axes
);
}
else
if
(
node_op
==
"Tan"
)
{
node
=
make_shared
<
op
::
Tan
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Tanh"
)
{
node
=
make_shared
<
op
::
Tanh
>
(
args
[
0
]);
}
else
if
(
node_op
==
"Tuple"
)
{
node
=
make_shared
<
op
::
Tuple
>
(
args
);
}
else
{
stringstream
ss
;
ss
<<
"unsupported op "
<<
node_op
;
throw
runtime_error
(
ss
.
str
());
}
node_map
[
node_name
]
=
node
;
}
auto
result
=
node_map
.
at
(
func_result
[
0
]);
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
params
;
for
(
auto
param_name
:
func_parameters
)
{
params
.
push_back
(
dynamic_pointer_cast
<
op
::
Parameter
>
(
node_map
.
at
(
param_name
)));
}
auto
rt
=
make_shared
<
TensorViewType
>
(
result_type
,
result_shape
);
rc
=
make_shared
<
Function
>
(
result
,
rt
,
params
,
func_name
);
function_map
[
func_name
]
=
rc
;
return
rc
;
}
json
write
(
const
Node
&
n
)
{
json
node
;
node
[
"name"
]
=
n
.
get_name
();
node
[
"op"
]
=
n
.
description
();
node
[
"element_type"
]
=
write_element_type
(
n
.
get_element_type
());
json
inputs
=
json
::
array
();
json
outputs
=
json
::
array
();
for
(
const
descriptor
::
Input
&
input
:
n
.
get_inputs
())
{
inputs
.
push_back
(
input
.
get_output
().
get_node
()
->
get_name
());
}
for
(
const
descriptor
::
Output
&
output
:
n
.
get_outputs
())
{
outputs
.
push_back
(
output
.
get_node
()
->
get_name
());
}
node
[
"inputs"
]
=
inputs
;
node
[
"outputs"
]
=
outputs
;
string
node_op
=
n
.
description
();
if
(
node_op
==
"Abs"
)
{
}
else
if
(
node_op
==
"Acos"
)
{
}
else
if
(
node_op
==
"Add"
)
{
}
else
if
(
node_op
==
"Asin"
)
{
}
else
if
(
node_op
==
"Atan"
)
{
}
else
if
(
node_op
==
"Broadcast"
)
{
auto
tmp
=
dynamic_cast
<
const
op
::
Broadcast
*>
(
&
n
);
node
[
"axes"
]
=
tmp
->
get_broadcast_axes
();
node
[
"shape"
]
=
tmp
->
get_broadcast_shape
();
}
else
if
(
node_op
==
"Ceiling"
)
{
}
else
if
(
node_op
==
"Concat"
)
{
auto
tmp
=
dynamic_cast
<
const
op
::
Concat
*>
(
&
n
);
node
[
"axis"
]
=
tmp
->
get_concatenation_axis
();
}
else
if
(
node_op
==
"Constant"
)
{
auto
tmp
=
dynamic_cast
<
const
op
::
Constant
*>
(
&
n
);
node
[
"value"
]
=
tmp
->
get_value_strings
();
node
[
"shape"
]
=
tmp
->
get_shape
();
}
else
if
(
node_op
==
"Convert"
)
{
auto
tmp
=
dynamic_cast
<
const
op
::
Convert
*>
(
&
n
);
node
[
"target_type"
]
=
write_element_type
(
tmp
->
get_convert_element_type
());
}
else
if
(
node_op
==
"Cos"
)
{
}
else
if
(
node_op
==
"Cosh"
)
{
}
else
if
(
node_op
==
"Divide"
)
{
}
else
if
(
node_op
==
"Dot"
)
{
}
else
if
(
node_op
==
"Equal"
)
{
}
else
if
(
node_op
==
"Exp"
)
{
}
else
if
(
node_op
==
"Floor"
)
{
}
else
if
(
node_op
==
"FunctionCall"
)
{
node
[
"function"
]
=
n
.
get_function
()
->
get_name
();
}
else
if
(
node_op
==
"GetTupleElement"
)
{
}
else
if
(
node_op
==
"Greater"
)
{
}
else
if
(
node_op
==
"GreaterEq"
)
{
}
else
if
(
node_op
==
"Less"
)
{
}
else
if
(
node_op
==
"LessEq"
)
{
}
else
if
(
node_op
==
"Log"
)
{
}
else
if
(
node_op
==
"Maximum"
)
{
}
else
if
(
node_op
==
"Minimum"
)
{
}
else
if
(
node_op
==
"Multiply"
)
{
}
else
if
(
node_op
==
"Negative"
)
{
}
else
if
(
node_op
==
"NotEqual"
)
{
}
else
if
(
node_op
==
"Parameter"
)
{
auto
tmp
=
dynamic_cast
<
const
op
::
Parameter
*>
(
&
n
);
node
[
"shape"
]
=
tmp
->
get_shape
();
}
else
if
(
node_op
==
"Power"
)
{
}
else
if
(
node_op
==
"Reduce"
)
{
auto
tmp
=
dynamic_cast
<
const
op
::
Reduce
*>
(
&
n
);
node
[
"function"
]
=
tmp
->
get_function
()
->
get_name
();
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
();
}
else
if
(
node_op
==
"Remainder"
)
{
}
else
if
(
node_op
==
"Reshape"
)
{
auto
tmp
=
dynamic_cast
<
const
op
::
Reshape
*>
(
&
n
);
node
[
"input_order"
]
=
tmp
->
get_input_order
();
node
[
"output_shape"
]
=
tmp
->
get_output_shape
();
}
else
if
(
node_op
==
"Select"
)
{
}
else
if
(
node_op
==
"Sign"
)
{
}
else
if
(
node_op
==
"Sin"
)
{
}
else
if
(
node_op
==
"Sinh"
)
{
}
else
if
(
node_op
==
"Slice"
)
{
auto
tmp
=
dynamic_cast
<
const
op
::
Slice
*>
(
&
n
);
node
[
"lower_bounds"
]
=
tmp
->
get_lower_bounds
();
node
[
"upper_bounds"
]
=
tmp
->
get_upper_bounds
();
node
[
"step"
]
=
tmp
->
get_step
();
}
else
if
(
node_op
==
"Subtract"
)
{
}
else
if
(
node_op
==
"Sum"
)
{
auto
tmp
=
dynamic_cast
<
const
op
::
Sum
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
();
}
else
if
(
node_op
==
"Tan"
)
{
}
else
if
(
node_op
==
"Tanh"
)
{
}
else
if
(
node_op
==
"Tuple"
)
{
}
return
node
;
}
src/ngraph/serializer.hpp
0 → 100644
View file @
33b51160
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/json.hpp"
#include "ngraph/node.hpp"
namespace
ngraph
{
std
::
string
serialize
(
std
::
shared_ptr
<
ngraph
::
Function
>
);
std
::
shared_ptr
<
ngraph
::
Function
>
deserialize
(
std
::
istream
&
);
}
src/ngraph/types/element_type.cpp
View file @
33b51160
...
...
@@ -21,7 +21,7 @@
using
namespace
ngraph
;
const
element
::
Type
element
::
boolean
(
8
,
false
,
false
,
"
bool
"
);
const
element
::
Type
element
::
boolean
(
8
,
false
,
false
,
"
char
"
);
const
element
::
Type
element
::
f32
(
32
,
true
,
true
,
"float"
);
const
element
::
Type
element
::
f64
(
64
,
true
,
true
,
"double"
);
const
element
::
Type
element
::
i8
(
8
,
false
,
true
,
"int8_t"
);
...
...
@@ -39,7 +39,6 @@ element::Type::Type(size_t bitwidth, bool is_real, bool is_signed, const std::st
,
m_is_signed
{
is_signed
}
,
m_cname
{
cname
}
{
assert
(
m_bitwidth
%
8
==
0
);
}
const
std
::
string
&
element
::
Type
::
c_type_string
()
const
...
...
@@ -53,13 +52,35 @@ bool element::Type::operator==(const element::Type& other) const
m_is_signed
==
other
.
m_is_signed
&&
m_cname
==
other
.
m_cname
;
}
bool
element
::
Type
::
operator
<
(
const
Type
&
other
)
const
{
size_t
v1
=
m_bitwidth
<<
2
;
v1
|=
(
m_is_real
?
2
:
0
);
v1
|=
(
m_is_signed
?
1
:
0
);
size_t
v2
=
other
.
m_bitwidth
<<
2
;
v2
|=
(
other
.
m_is_real
?
2
:
0
);
v2
|=
(
other
.
m_is_signed
?
1
:
0
);
return
v1
<
v2
;
}
size_t
element
::
Type
::
size
()
const
{
return
std
::
ceil
(
static_cast
<
float
>
(
m_bitwidth
)
/
8.0
f
);
}
size_t
element
::
Type
::
hash
()
const
{
size_t
h1
=
std
::
hash
<
size_t
>
{}(
m_bitwidth
);
size_t
h2
=
std
::
hash
<
bool
>
{}(
m_is_real
);
size_t
h3
=
std
::
hash
<
bool
>
{}(
m_is_signed
);
return
h1
^
((
h2
^
(
h3
<<
1
))
<<
1
);
}
std
::
ostream
&
element
::
operator
<<
(
std
::
ostream
&
out
,
const
element
::
Type
&
obj
)
{
out
<<
obj
.
m_cname
;
out
<<
"element::Type("
<<
obj
.
m_bitwidth
<<
", "
<<
obj
.
m_is_real
<<
", "
<<
obj
.
m_is_signed
<<
")"
;
return
out
;
}
src/ngraph/types/element_type.hpp
View file @
33b51160
...
...
@@ -47,23 +47,20 @@ namespace ngraph
class
Type
{
Type
(
const
Type
&
)
=
delete
;
Type
&
operator
=
(
const
Type
&
)
=
delete
;
public
:
virtual
~
Type
()
{}
Type
()
=
delete
;
Type
(
const
Type
&
)
=
default
;
Type
(
size_t
bitwidth
,
bool
is_real
,
bool
is_signed
,
const
std
::
string
&
cname
);
virtual
~
Type
()
{}
const
std
::
string
&
c_type_string
()
const
;
size_t
size
()
const
;
size_t
hash
()
const
{
std
::
hash
<
std
::
string
>
h
;
return
h
(
m_cname
);
}
size_t
hash
()
const
;
bool
is_real
()
const
{
return
m_is_real
;
}
bool
is_signed
()
const
{
return
m_is_signed
;
}
size_t
bitwidth
()
const
{
return
m_bitwidth
;
}
bool
operator
==
(
const
Type
&
other
)
const
;
bool
operator
!=
(
const
Type
&
other
)
const
{
return
!
(
*
this
==
other
);
}
bool
operator
<
(
const
Type
&
other
)
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Type
&
);
private
:
...
...
src/ngraph/util.cpp
View file @
33b51160
...
...
@@ -145,7 +145,6 @@ void ngraph::traverse_nodes(std::shared_ptr<ngraph::Function> p,
}
void
ngraph
::
traverse_nodes
(
ngraph
::
Function
*
p
,
std
::
function
<
void
(
shared_ptr
<
Node
>
)
>
f
)
{
std
::
unordered_set
<
shared_ptr
<
Node
>>
instances_seen
;
deque
<
shared_ptr
<
Node
>>
stack
;
...
...
@@ -172,6 +171,34 @@ void ngraph::traverse_nodes(ngraph::Function* p, std::function<void(shared_ptr<N
}
}
void
ngraph
::
traverse_functions
(
std
::
shared_ptr
<
ngraph
::
Function
>
p
,
std
::
function
<
void
(
shared_ptr
<
Function
>
)
>
f
)
{
std
::
unordered_set
<
shared_ptr
<
Function
>>
instances_seen
;
deque
<
shared_ptr
<
Function
>>
stack
;
stack
.
push_front
(
p
);
while
(
stack
.
size
()
>
0
)
{
shared_ptr
<
Function
>
func
=
stack
.
front
();
if
(
instances_seen
.
find
(
func
)
==
instances_seen
.
end
())
{
instances_seen
.
insert
(
func
);
f
(
func
);
}
stack
.
pop_front
();
for
(
shared_ptr
<
Node
>
op
:
func
->
get_ops
())
{
shared_ptr
<
Function
>
fp
=
op
->
get_function
();
if
(
fp
)
{
stack
.
push_front
(
fp
);
}
}
}
}
void
ngraph
::
free_nodes
(
shared_ptr
<
Function
>
p
)
{
std
::
deque
<
Node
*>
sorted_list
;
...
...
src/ngraph/util.hpp
View file @
33b51160
...
...
@@ -18,6 +18,7 @@
#include <chrono>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <sstream>
...
...
@@ -239,8 +240,9 @@ namespace ngraph
}
void
traverse_nodes
(
Function
*
p
,
std
::
function
<
void
(
std
::
shared_ptr
<
Node
>
)
>
f
);
void
traverse_nodes
(
std
::
shared_ptr
<
Function
>
p
,
std
::
function
<
void
(
std
::
shared_ptr
<
Node
>
)
>
f
);
void
traverse_functions
(
std
::
shared_ptr
<
Function
>
p
,
std
::
function
<
void
(
std
::
shared_ptr
<
Function
>
)
>
f
);
void
free_nodes
(
std
::
shared_ptr
<
Function
>
);
}
// end namespace ngraph
test/backend_test.in.cpp
View file @
33b51160
...
...
@@ -14,9 +14,10 @@
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include "gtest/gtest.h"
#include <cmath>
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
...
...
test/copy.cpp
View file @
33b51160
...
...
@@ -358,7 +358,7 @@ TEST(copy, reduce)
ASSERT_TRUE
(
nullptr
!=
new_node
);
ASSERT_TRUE
(
new_args
==
new_node
->
get_arguments
());
ASSERT_TRUE
(
f
==
node_cast
->
get_
reduction_
function
());
ASSERT_TRUE
(
f
==
node_cast
->
get_function
());
ASSERT_TRUE
(
axes
==
node_cast
->
get_reduction_axes
());
}
...
...
test/element_type.cpp
View file @
33b51160
...
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <map>
#include "gtest/gtest.h"
#include "ngraph/types/element_type.hpp"
...
...
@@ -33,3 +35,50 @@ TEST(element_type, from)
EXPECT_EQ
(
element
::
from
<
uint32_t
>
(),
element
::
u32
);
EXPECT_EQ
(
element
::
from
<
uint64_t
>
(),
element
::
u64
);
}
TEST
(
element_type
,
mapable
)
{
std
::
map
<
element
::
Type
,
std
::
string
>
test_map
;
test_map
.
insert
({
element
::
f32
,
"float"
});
}
TEST
(
element_type
,
size
)
{
{
element
::
Type
t1
{
1
,
false
,
false
,
""
};
EXPECT_EQ
(
1
,
t1
.
size
());
}
{
element
::
Type
t1
{
2
,
false
,
false
,
""
};
EXPECT_EQ
(
1
,
t1
.
size
());
}
{
element
::
Type
t1
{
3
,
false
,
false
,
""
};
EXPECT_EQ
(
1
,
t1
.
size
());
}
{
element
::
Type
t1
{
4
,
false
,
false
,
""
};
EXPECT_EQ
(
1
,
t1
.
size
());
}
{
element
::
Type
t1
{
5
,
false
,
false
,
""
};
EXPECT_EQ
(
1
,
t1
.
size
());
}
{
element
::
Type
t1
{
6
,
false
,
false
,
""
};
EXPECT_EQ
(
1
,
t1
.
size
());
}
{
element
::
Type
t1
{
7
,
false
,
false
,
""
};
EXPECT_EQ
(
1
,
t1
.
size
());
}
{
element
::
Type
t1
{
2
,
false
,
false
,
""
};
EXPECT_EQ
(
1
,
t1
.
size
());
}
{
element
::
Type
t1
{
9
,
false
,
false
,
""
};
EXPECT_EQ
(
2
,
t1
.
size
());
}
}
test/serialize.cpp
View file @
33b51160
...
...
@@ -12,4 +12,88 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <fstream>
#include <sstream>
#include "gtest/gtest.h"
#include "ngraph/json.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
template
<
typename
T
>
static
void
copy_data
(
shared_ptr
<
runtime
::
TensorView
>
tv
,
const
vector
<
T
>&
data
)
{
size_t
data_size
=
data
.
size
()
*
sizeof
(
T
);
tv
->
write
(
data
.
data
(),
0
,
data_size
);
}
TEST
(
serialize
,
main
)
{
// First create "f(A,B,C) = (A+B)*C".
auto
shape
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_f
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
f
=
make_shared
<
Function
>
((
A
+
B
)
*
C
,
rt_f
,
op
::
Parameters
{
A
,
B
,
C
},
"f"
);
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_g
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
g
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
})
+
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
}),
rt_g
,
op
::
Parameters
{
X
,
Y
,
Z
},
"g"
);
// Now make "h(X,Y,Z) = g(X,Y,Z) + g(X,Y,Z)"
auto
X1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_h
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
h
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
g
,
Nodes
{
X1
,
Y1
,
Z1
})
+
make_shared
<
op
::
FunctionCall
>
(
g
,
Nodes
{
X1
,
Y1
,
Z1
}),
rt_h
,
op
::
Parameters
{
X1
,
Y1
,
Z1
},
"h"
);
string
js
=
serialize
(
h
);
{
ofstream
f
(
"serialize_function.js"
);
f
<<
js
;
}
istringstream
in
(
js
);
shared_ptr
<
Function
>
sfunc
=
deserialize
(
in
);
// Now call g on some test vectors.
auto
manager
=
runtime
::
Manager
::
get
(
"CPU"
);
auto
external
=
manager
->
compile
(
sfunc
);
auto
backend
=
manager
->
allocate_backend
();
auto
cf
=
backend
->
make_call_frame
(
external
);
auto
x
=
backend
->
make_primary_tensor_view
(
element
::
Float32
::
element_type
(),
shape
);
copy_data
(
x
,
vector
<
float
>
{
1
,
2
,
3
,
4
});
auto
y
=
backend
->
make_primary_tensor_view
(
element
::
Float32
::
element_type
(),
shape
);
copy_data
(
y
,
vector
<
float
>
{
5
,
6
,
7
,
8
});
auto
z
=
backend
->
make_primary_tensor_view
(
element
::
Float32
::
element_type
(),
shape
);
copy_data
(
z
,
vector
<
float
>
{
9
,
10
,
11
,
12
});
auto
result
=
backend
->
make_primary_tensor_view
(
element
::
Float32
::
element_type
(),
shape
);
cf
->
call
({
x
,
y
,
z
},
{
result
});
EXPECT_EQ
((
vector
<
float
>
{
54
,
80
,
110
,
144
}),
result
->
get_vector
<
float
>
());
cf
->
call
({
y
,
x
,
z
},
{
result
});
EXPECT_EQ
((
vector
<
float
>
{
54
,
80
,
110
,
144
}),
result
->
get_vector
<
float
>
());
cf
->
call
({
x
,
z
,
y
},
{
result
});
EXPECT_EQ
((
vector
<
float
>
{
50
,
72
,
98
,
128
}),
result
->
get_vector
<
float
>
());
}
test/util.cpp
View file @
33b51160
...
...
@@ -18,6 +18,7 @@
#include "gtest/gtest.h"
#include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
...
...
@@ -202,3 +203,40 @@ TEST(util, all_close)
EXPECT_FALSE
(
ngraph
::
test
::
all_close
<
float
>
(
c
,
a
,
.05
f
,
0
));
EXPECT_TRUE
(
ngraph
::
test
::
all_close
<
float
>
(
c
,
a
,
.11
f
,
0
));
}
TEST
(
util
,
traverse_functions
)
{
// First create "f(A,B,C) = (A+B)*C".
auto
shape
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_f
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
f
=
make_shared
<
Function
>
((
A
+
B
)
*
C
,
rt_f
,
op
::
Parameters
{
A
,
B
,
C
},
"f"
);
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_g
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
g
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
})
+
make_shared
<
op
::
FunctionCall
>
(
f
,
Nodes
{
X
,
Y
,
Z
}),
rt_g
,
op
::
Parameters
{
X
,
Y
,
Z
},
"g"
);
// Now make "h(X,Y,Z) = g(X,Y,Z) + g(X,Y,Z)"
auto
X1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Y1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
Z1
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
rt_h
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape
);
auto
h
=
make_shared
<
Function
>
(
make_shared
<
op
::
FunctionCall
>
(
g
,
Nodes
{
X1
,
Y1
,
Z1
})
+
make_shared
<
op
::
FunctionCall
>
(
g
,
Nodes
{
X1
,
Y1
,
Z1
}),
rt_h
,
op
::
Parameters
{
X1
,
Y1
,
Z1
},
"h"
);
vector
<
Function
*>
functions
;
traverse_functions
(
h
,
[
&
](
shared_ptr
<
Function
>
fp
)
{
functions
.
push_back
(
fp
.
get
());
});
ASSERT_EQ
(
3
,
functions
.
size
());
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment