ngraph — commit 6f1e728a (unverified)
Authored Jun 25, 2019 by Fenglei Tian; committed by GitHub on Jun 25, 2019.

Merge branch 'master' into tfl/send_recv_op

Parents: 9c2230aa, d0a83a35

Showing 82 changed files with 586 additions and 365 deletions (+586 / -365).
Changed files:

  src/ngraph/distributed.cpp                        +5   -20
  src/ngraph/distributed.hpp                        +5   -24
  src/ngraph/distributed/mlsl.hpp                   +5   -5
  src/ngraph/distributed/open_mpi.hpp               +5   -5
  src/ngraph/graph_util.cpp                         +5   -0
  src/ngraph/node.cpp                               +10  -0
  src/ngraph/node.hpp                               +11  -6
  src/ngraph/op/abs.cpp                             +0   -4
  src/ngraph/op/abs.hpp                             +1   -1
  src/ngraph/op/acos.cpp                            +0   -4
  src/ngraph/op/acos.hpp                            +1   -1
  src/ngraph/op/add.cpp                             +0   -4
  src/ngraph/op/add.hpp                             +1   -1
  src/ngraph/op/all.cpp                             +0   -4
  src/ngraph/op/all.hpp                             +1   -1
  src/ngraph/op/allreduce.cpp                       +7   -7
  src/ngraph/op/allreduce.hpp                       +4   -4
  src/ngraph/op/and.cpp                             +0   -4
  src/ngraph/op/and.hpp                             +1   -1
  src/ngraph/op/any.cpp                             +0   -4
  src/ngraph/op/any.hpp                             +1   -1
  src/ngraph/op/argmax.cpp                          +0   -4
  src/ngraph/op/argmax.hpp                          +1   -1
  src/ngraph/op/argmin.cpp                          +0   -4
  src/ngraph/op/argmin.hpp                          +1   -1
  src/ngraph/op/asin.cpp                            +0   -4
  src/ngraph/op/asin.hpp                            +1   -1
  src/ngraph/op/atan.cpp                            +0   -4
  src/ngraph/op/atan.hpp                            +1   -1
  src/ngraph/op/avg_pool.cpp                        +0   -8
  src/ngraph/op/avg_pool.hpp                        +2   -2
  src/ngraph/op/batch_norm.cpp                      +42  -40
  src/ngraph/op/batch_norm.hpp                      +42  -28
  src/ngraph/op/broadcast.cpp                       +11  -10
  src/ngraph/op/broadcast.hpp                       +25  -10
  src/ngraph/op/broadcast_distributed.cpp           +9   -2
  src/ngraph/op/broadcast_distributed.hpp           +7   -2
  src/ngraph/op/ceiling.cpp                         +4   -2
  src/ngraph/op/ceiling.hpp                         +6   -1
  src/ngraph/op/concat.cpp                          +9   -2
  src/ngraph/op/concat.hpp                          +17  -1
  src/ngraph/op/constant.cpp                        +2   -0
  src/ngraph/op/constant.hpp                        +5   -2
  src/ngraph/op/convert.cpp                         +4   -2
  src/ngraph/op/convert.hpp                         +12  -2
  src/ngraph/op/convolution.cpp                     +27  -21
  src/ngraph/op/convolution.hpp                     +0   -0
  src/ngraph/op/cos.cpp                             +4   -2
  src/ngraph/op/cos.hpp                             +6   -1
  src/ngraph/op/cosh.cpp                            +4   -2
  src/ngraph/op/cosh.hpp                            +6   -1
  src/ngraph/op/dequantize.cpp                      +6   -4
  src/ngraph/op/dequantize.hpp                      +17  -8
  src/ngraph/op/divide.cpp                          +9   -8
  src/ngraph/op/divide.hpp                          +13  -7
  src/ngraph/op/dot.cpp                             +7   -5
  src/ngraph/op/dot.hpp                             +17  -3
  src/ngraph/op/embedding_lookup.cpp                +2   -0
  src/ngraph/op/embedding_lookup.hpp                +7   -2
  src/ngraph/op/equal.cpp                           +4   -4
  src/ngraph/op/equal.hpp                           +8   -3
  src/ngraph/op/erf.cpp                             +4   -2
  src/ngraph/op/erf.hpp                             +5   -1
  src/ngraph/op/exp.cpp                             +4   -2
  src/ngraph/op/exp.hpp                             +6   -1
  src/ngraph/op/floor.cpp                           +4   -2
  src/ngraph/op/floor.hpp                           +6   -1
  src/ngraph/op/fused/split.cpp                     +7   -0
  src/ngraph/op/reshape.cpp                         +4   -2
  src/ngraph/op/reshape.hpp                         +11  -3
  src/ngraph/op/result.cpp                          +4   -2
  src/ngraph/op/result.hpp                          +6   -1
  src/ngraph/pattern/matcher.cpp                    +15  -4
  src/ngraph/runtime/backend.cpp                    +6   -0
  src/ngraph/runtime/backend.hpp                    +9   -0
  src/ngraph/runtime/interpreter/int_backend.cpp    +13  -0
  src/ngraph/runtime/interpreter/int_backend.hpp    +2   -0
  src/ngraph/serializer.cpp                         +18  -0
  test/backend_api.cpp                              +22  -0
  test/distributed.in.cpp                           +10  -10
  test/pattern.cpp                                  +27  -0
  test/util/autodiff/backprop_derivative.hpp        +2   -28
src/ngraph/distributed.cpp

@@ -22,11 +22,6 @@
 using namespace ngraph;
 
-NGRAPH_API const reduction::Type reduction::sum(reduction::Type_t::sum);
-NGRAPH_API const reduction::Type reduction::prod(reduction::Type_t::prod);
-NGRAPH_API const reduction::Type reduction::min(reduction::Type_t::min);
-NGRAPH_API const reduction::Type reduction::max(reduction::Type_t::max);
-
 std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj)
 {
 #if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))

@@ -34,12 +29,12 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& ob
 #pragma GCC diagnostic error "-Wswitch"
 #pragma GCC diagnostic error "-Wswitch-enum"
 #endif
-    switch (obj.get_type())
+    switch (obj)
     {
-    case reduction::Type_t::sum: out << "sum"; break;
-    case reduction::Type_t::prod: out << "prod"; break;
-    case reduction::Type_t::min: out << "min"; break;
-    case reduction::Type_t::max: out << "max"; break;
+    case reduction::Type::SUM: out << "SUM"; break;
+    case reduction::Type::PROD: out << "PROD"; break;
+    case reduction::Type::MIN: out << "MIN"; break;
+    case reduction::Type::MAX: out << "MAX"; break;
     }
 #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
 #pragma GCC diagnostic pop

@@ -47,16 +42,6 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& ob
     return out;
 };
 
-bool reduction::Type::operator==(const reduction::Type& other) const
-{
-    return m_type == other.m_type;
-}
-
-reduction::Type_t reduction::Type::get_type() const
-{
-    return m_type;
-}
-
 static std::unique_ptr<DistributedInterface> s_distributed_interface;
 
 void ngraph::set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface)
src/ngraph/distributed.hpp

@@ -26,34 +26,15 @@ namespace ngraph
 {
     namespace reduction
     {
-        enum class Type_t
+        enum class Type
         {
-            sum,
-            prod,
-            min,
-            max,
-        };
-
-        class Type
-        {
-        public:
-            Type(const Type_t t)
-                : m_type(t)
-            {
-            }
-            friend std::ostream& operator<<(std::ostream&, const Type&);
-            bool operator==(const Type& other) const;
-            bool operator!=(const Type& other) const { return !(*this == other); }
-            Type_t get_type() const;
-
-        private:
-            Type_t m_type;
+            SUM,
+            PROD,
+            MIN,
+            MAX,
         };
 
         std::ostream& operator<<(std::ostream& out, const Type& obj);
-
-        extern NGRAPH_API const Type sum;
-        extern NGRAPH_API const Type prod;
-        extern NGRAPH_API const Type min;
-        extern NGRAPH_API const Type max;
     }
 
     class DistributedInterface
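This hunk replaces the reduction::Type wrapper class with a plain enum class, so reduce kinds are now compared and printed directly. A minimal usage sketch under that reading (the include path and the main() wrapper are assumptions, not part of this diff):

    #include <iostream>
    #include "ngraph/distributed.hpp" // assumed include path for the header above

    int main()
    {
        // A reduce kind is now just an enum value, not a wrapper object.
        ngraph::reduction::Type t = ngraph::reduction::Type::SUM;

        // The stream operator declared above prints "SUM", "PROD", "MIN" or "MAX",
        // per the switch in src/ngraph/distributed.cpp.
        std::cout << t << std::endl;

        // Built-in enum comparison replaces the removed Type::operator==.
        if (t == ngraph::reduction::Type::SUM)
        {
            std::cout << "summing across ranks" << std::endl;
        }
        return 0;
    }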
src/ngraph/distributed/mlsl.hpp

@@ -92,14 +92,14 @@ namespace ngraph
 #pragma GCC diagnostic error "-Wswitch"
 #pragma GCC diagnostic error "-Wswitch-enum"
 #endif
-            switch (reduce_type.get_type())
+            switch (reduce_type)
             {
-            case reduction::Type_t::sum: mlsl_reduce_type = MLSL::RT_SUM; break;
-            case reduction::Type_t::prod:
+            case reduction::Type::SUM: mlsl_reduce_type = MLSL::RT_SUM; break;
+            case reduction::Type::PROD:
                 throw std::runtime_error("MLSL doesn't support allreduce prod");
                 break;
-            case reduction::Type_t::min: mlsl_reduce_type = MLSL::RT_MIN; break;
-            case reduction::Type_t::max: mlsl_reduce_type = MLSL::RT_MAX; break;
+            case reduction::Type::MIN: mlsl_reduce_type = MLSL::RT_MIN; break;
+            case reduction::Type::MAX: mlsl_reduce_type = MLSL::RT_MAX; break;
             }
 #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
 #pragma GCC diagnostic pop
src/ngraph/distributed/open_mpi.hpp

@@ -104,12 +104,12 @@ namespace ngraph
 #pragma GCC diagnostic error "-Wswitch"
 #pragma GCC diagnostic error "-Wswitch-enum"
 #endif
-            switch (reduce_type.get_type())
+            switch (reduce_type)
             {
-            case reduction::Type_t::sum: mpi_reduce_type = MPI_SUM; break;
-            case reduction::Type_t::prod: mpi_reduce_type = MPI_PROD; break;
-            case reduction::Type_t::min: mpi_reduce_type = MPI_MIN; break;
-            case reduction::Type_t::max: mpi_reduce_type = MPI_MAX; break;
+            case reduction::Type::SUM: mpi_reduce_type = MPI_SUM; break;
+            case reduction::Type::PROD: mpi_reduce_type = MPI_PROD; break;
+            case reduction::Type::MIN: mpi_reduce_type = MPI_MIN; break;
+            case reduction::Type::MAX: mpi_reduce_type = MPI_MAX; break;
             }
 #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
 #pragma GCC diagnostic pop
src/ngraph/graph_util.cpp

@@ -233,6 +233,11 @@ std::list<std::shared_ptr<ngraph::Node>>
             // There is a friendly name for this node so copy it
             cloned_node->set_friendly_name(node->get_friendly_name());
         }
+
+        for (auto tag : node->get_provenance_tags())
+        {
+            cloned_node->add_provenance_tag(tag);
+        }
         node_map[node.get()] = cloned_node;
     }
 }
src/ngraph/node.cpp

@@ -559,6 +559,16 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
     return args;
 }
 
+OutputVector ngraph::as_output_vector(const NodeVector& args)
+{
+    OutputVector output_vector;
+    for (auto& arg : check_single_output_args(args))
+    {
+        output_vector.push_back(arg);
+    }
+    return output_vector;
+}
+
 std::tuple<element::Type, PartialShape>
     Node::validate_and_infer_elementwise_args(const op::AutoBroadcastSpec& autob)
 {
src/ngraph/node.hpp

@@ -73,6 +73,8 @@ namespace ngraph
                                          size_t i);
 
     const NodeVector& check_single_output_args(const NodeVector& args);
+
+    OutputVector as_output_vector(const NodeVector& args);
 
     /// Alias useful for cloning
     using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;

@@ -487,7 +489,7 @@ namespace ngraph
         /// \param node A pointer to the node for the output handle.
         /// \param index The index of the output.
         Output(NodeType* node, size_t index)
-            : m_node(node)
+            : m_node(node->shared_from_this())
             , m_index(index)
         {
         }

@@ -498,7 +500,7 @@ namespace ngraph
         ///
         /// TODO: Make a plan to deprecate this.
         Output(const std::shared_ptr<NodeType>& node, size_t index)
-            : m_node(node.get())
+            : m_node(node)
            , m_index(index)
        {
        }

@@ -511,12 +513,15 @@ namespace ngraph
         {
         }
 
+        // A null output
+        Output() = default;
+
         /// \return A pointer to the node referred to by this output handle.
-        NodeType* get_node() const { return m_node; }
+        NodeType* get_node() const { return m_node.get(); }
         /// \return A `shared_ptr` to the node referred to by this output handle.
         ///
         /// TODO: Make a plan to deprecate this.
-        std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node->shared_from_this(); }
+        std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; }
         /// \return The index of the output referred to by this output handle.
         size_t get_index() const { return m_index; }
         /// \return A reference to the tensor descriptor for this output.

@@ -568,8 +573,8 @@ namespace ngraph
         bool operator<=(const Output& other) const { return !(*this > other); }
         bool operator>=(const Output& other) const { return !(*this < other); }
 
     private:
-        NodeType* const m_node;
-        const size_t m_index;
+        std::shared_ptr<NodeType> m_node;
+        size_t m_index{0};
     };
 
     inline Input<Node> Node::input(size_t input_index)
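The node.hpp hunks above change Output<NodeType> from holding a raw Node* to holding a shared_ptr, so an output handle now keeps its node alive on its own. A short sketch of how the handle behaves after this change (op::Parameter and its include path are assumed from the wider nGraph API; they are not part of this diff):

    #include <memory>
    #include "ngraph/node.hpp"
    #include "ngraph/op/parameter.hpp" // assumed include path for op::Parameter

    void output_handle_sketch()
    {
        using namespace ngraph;

        // op::Parameter is an assumption here; any Node subclass would do.
        auto param = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});

        // Construct a handle to output 0; the handle now stores the shared_ptr itself.
        Output<Node> out{param, 0};

        Node* raw = out.get_node();                              // m_node.get()
        std::shared_ptr<Node> owned = out.get_node_shared_ptr(); // the stored shared_ptr
        size_t index = out.get_index();                          // which output of the node
        (void)raw;
        (void)owned;
        (void)index;
    }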
src/ngraph/op/abs.cpp

@@ -23,10 +23,6 @@ using namespace ngraph;
 
 const string op::Abs::type_name{"Abs"};
 
-op::Abs::Abs()
-{
-}
-
 op::Abs::Abs(const Output<Node>& arg)
     : UnaryElementwiseArithmetic(arg)
 {
src/ngraph/op/abs.hpp

@@ -33,7 +33,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs an absolute value operation.
-           Abs();
+           Abs() = default;
            /// \brief Constructs an absolute value operation.
            ///
src/ngraph/op/acos.cpp

@@ -34,10 +34,6 @@ using namespace ngraph;
 
 const string op::Acos::type_name{"Acos"};
 
-op::Acos::Acos()
-{
-}
-
 op::Acos::Acos(const Output<Node>& arg)
     : UnaryElementwiseArithmetic(arg)
 {
src/ngraph/op/acos.hpp

@@ -33,7 +33,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs an arccos operation.
-           Acos();
+           Acos() = default;
            /// \brief Constructs an arccos operation.
            ///
            /// \param arg Output that produces the input tensor.<br>
src/ngraph/op/add.cpp

@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::Add::type_name{"Add"};
 
-op::Add::Add()
-{
-}
-
 op::Add::Add(const Output<Node>& arg0,
              const Output<Node>& arg1,
              const AutoBroadcastSpec& autob)
     : BinaryElementwiseArithmetic(arg0, arg1, autob)
 {
src/ngraph/op/add.hpp

@@ -33,7 +33,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs an unitialized addition operation
-           Add();
+           Add() = default;
 
            /// \brief Constructs an addition operation.
            ///
src/ngraph/op/all.cpp

@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::All::type_name{"All"};
 
-op::All::All()
-{
-}
-
 op::All::All(const Output<Node>& arg, const AxisSet& reduction_axes)
     : LogicalReduction(arg, reduction_axes)
 {
src/ngraph/op/all.hpp

@@ -33,7 +33,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs an "all" reduction operation.
-           All();
+           All() = default;
            /// \brief Constructs an "all" reduction operation.
            ///
            /// \param arg The tensor to be reduced.
src/ngraph/op/allreduce.cpp

@@ -21,13 +21,8 @@ using namespace ngraph;
 
 const string op::AllReduce::type_name{"AllReduce"};
 
-op::AllReduce::AllReduce()
-    : m_reduce_type(reduction::sum)
-{
-}
-
-op::AllReduce::AllReduce(const shared_ptr<Node>& arg, const reduction::Type reduce_type)
-    : Op(check_single_output_args({arg}))
+op::AllReduce::AllReduce(const Output<Node>& arg, reduction::Type reduce_type)
+    : Op({arg})
     , m_reduce_type(reduce_type)
 {
     constructor_validate_and_infer_types();

@@ -56,3 +51,8 @@ reduction::Type op::AllReduce::get_reduce_type() const
 {
     return m_reduce_type;
 }
+
+void op::AllReduce::set_reduce_type(reduction::Type reduce_type)
+{
+    m_reduce_type = reduce_type;
+}
src/ngraph/op/allreduce.hpp

@@ -29,17 +29,17 @@ namespace ngraph
            NGRAPH_API
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
-           AllReduce();
-           AllReduce(const std::shared_ptr<Node>& arg,
-                     const reduction::Type reduce_type = reduction::sum);
+           AllReduce() = default;
+           AllReduce(const Output<Node>& arg,
+                     reduction::Type reduce_type = reduction::Type::SUM);
 
            void validate_and_infer_types() override;
            std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
 
            reduction::Type get_reduce_type() const;
+           void set_reduce_type(reduction::Type reduce_type);
 
        private:
-           const reduction::Type m_reduce_type;
+           reduction::Type m_reduce_type{reduction::Type::SUM};
        };
    }
 }
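With these changes an AllReduce can be default-constructed, takes an Output<Node> argument, defaults to reduction::Type::SUM, and the reduce kind can be changed after construction. A hedged construction sketch (op::Parameter and the include paths are assumptions; the AllReduce API itself is as declared above):

    #include <memory>
    #include "ngraph/op/allreduce.hpp"
    #include "ngraph/op/parameter.hpp" // assumed include path

    void allreduce_sketch()
    {
        using namespace ngraph;
        auto arg = std::make_shared<op::Parameter>(element::f32, Shape{4});

        // reduce_type defaults to reduction::Type::SUM.
        auto sum_all = std::make_shared<op::AllReduce>(Output<Node>(arg, 0));

        // Pick another reduce kind explicitly, and change it later via the new setter.
        auto min_all = std::make_shared<op::AllReduce>(Output<Node>(arg, 0), reduction::Type::MIN);
        min_all->set_reduce_type(reduction::Type::MAX);
        reduction::Type t = min_all->get_reduce_type(); // reduction::Type::MAX
        (void)sum_all;
        (void)t;
    }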
src/ngraph/op/and.cpp

@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::And::type_name{"And"};
 
-op::And::And()
-{
-}
-
 op::And::And(const Output<Node>& arg0,
              const Output<Node>& arg1,
              const AutoBroadcastSpec& autob)
     : BinaryElementwiseLogical(arg0, arg1, autob)
 {
src/ngraph/op/and.hpp

@@ -33,7 +33,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs a logical-and operation.
-           And();
+           And() = default;
            /// \brief Constructs a logical-and operation.
            ///
src/ngraph/op/any.cpp

@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::Any::type_name{"Any"};
 
-op::Any::Any()
-{
-}
-
 op::Any::Any(const Output<Node>& arg, const AxisSet& reduction_axes)
     : LogicalReduction(arg, reduction_axes)
 {
src/ngraph/op/any.hpp

@@ -33,7 +33,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs an "any" reduction operation.
-           Any();
+           Any() = default;
            /// \brief Constructs an "any" reduction operation.
            ///
            /// \param arg The tensor to be reduced.
src/ngraph/op/argmax.cpp

@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::ArgMax::type_name{"ArgMax"};
 
-op::ArgMax::ArgMax()
-{
-}
-
 op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
     : op::util::IndexReduction(arg, axis, index_element_type)
 {
src/ngraph/op/argmax.hpp

@@ -32,7 +32,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs a ArgMax operation.
-           ArgMax();
+           ArgMax() = default;
            /// \brief Constructs a ArgMax operation.
            ///
            /// \param arg The input tensor
src/ngraph/op/argmin.cpp

@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::ArgMin::type_name{"ArgMin"};
 
-op::ArgMin::ArgMin()
-{
-}
-
 op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
     : op::util::IndexReduction(arg, axis, index_element_type)
 {
src/ngraph/op/argmin.hpp

@@ -32,7 +32,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs a ArgMin operation.
-           ArgMin();
+           ArgMin() = default;
            /// \brief Constructs a ArgMin operation.
            ///
src/ngraph/op/asin.cpp

@@ -33,10 +33,6 @@ using namespace ngraph;
 
 const string op::Asin::type_name{"Asin"};
 
-op::Asin::Asin()
-{
-}
-
 op::Asin::Asin(const Output<Node>& arg)
     : UnaryElementwiseArithmetic(arg)
 {
src/ngraph/op/asin.hpp

@@ -33,7 +33,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs an arcsin operation.
-           Asin();
+           Asin() = default;
            /// \brief Constructs an arcsin operation.
            ///
            /// \param arg Output that produces the input tensor.<br>
src/ngraph/op/atan.cpp

@@ -32,10 +32,6 @@ using namespace ngraph;
 
 const string op::Atan::type_name{"Atan"};
 
-op::Atan::Atan()
-{
-}
-
 op::Atan::Atan(const Output<Node>& arg)
     : UnaryElementwiseArithmetic(arg)
 {
src/ngraph/op/atan.hpp

@@ -33,7 +33,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs an arctan operation.
-           Atan();
+           Atan() = default;
            /// \brief Constructs an arctan operation.
            ///
src/ngraph/op/avg_pool.cpp

@@ -23,10 +23,6 @@ using namespace ngraph;
 
 const string op::AvgPool::type_name{"AvgPool"};
 
-op::AvgPool::AvgPool()
-{
-}
-
 op::AvgPool::AvgPool(const Output<Node>& arg,
                      const Shape& window_shape,
                      const Strides& window_movement_strides,

@@ -231,10 +227,6 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con
 
 const string op::AvgPoolBackprop::type_name("AvgPoolBackprop");
 
-op::AvgPoolBackprop::AvgPoolBackprop()
-{
-}
-
 op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
                                      const shared_ptr<Node>& delta,
                                      const Shape& window_shape,
src/ngraph/op/avg_pool.hpp

@@ -33,7 +33,7 @@ namespace ngraph
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
 
            /// \brief Constructs a batched average pooling operation.
-           AvgPool();
+           AvgPool() = default;
            /// \brief Constructs a batched average pooling operation.
            ///

@@ -175,7 +175,7 @@ namespace ngraph
        public:
            static const std::string type_name;
            const std::string& description() const override { return type_name; }
-           AvgPoolBackprop();
+           AvgPoolBackprop() = default;
            AvgPoolBackprop(const Shape& forward_arg_shape,
                            const std::shared_ptr<Node>& delta,
                            const Shape& window_shape,
src/ngraph/op/batch_norm.cpp

@@ -22,11 +22,13 @@
 #include "ngraph/op/get_output_element.hpp"
 #include "ngraph/validation_util.hpp"
 
-ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> input,
-                                                 std::shared_ptr<ngraph::Node> gamma,
-                                                 std::shared_ptr<ngraph::Node> beta,
+const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"};
+
+ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
+                                                 Output<ngraph::Node> gamma,
+                                                 Output<ngraph::Node> beta,
                                                  double epsilon)
-    : Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
+    : Op({gamma, beta, input})
     , m_epsilon(epsilon)
 {
     constructor_validate_and_infer_types();

@@ -34,10 +36,10 @@ ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> i
 // DEPRECATED
 ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
-                                                 std::shared_ptr<ngraph::Node> gamma,
-                                                 std::shared_ptr<ngraph::Node> beta,
-                                                 std::shared_ptr<ngraph::Node> input)
-    : Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
+                                                 Output<ngraph::Node> gamma,
+                                                 Output<ngraph::Node> beta,
+                                                 Output<ngraph::Node> input)
+    : Op({gamma, beta, input})
     , m_epsilon(eps)
 {
     constructor_validate_and_infer_types();

@@ -111,13 +113,15 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
     adjoints.add_delta(beta, dbeta);
 }
 
-ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node> input,
-                                                   std::shared_ptr<ngraph::Node> gamma,
-                                                   std::shared_ptr<ngraph::Node> beta,
-                                                   std::shared_ptr<ngraph::Node> mean,
-                                                   std::shared_ptr<ngraph::Node> variance,
+const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"};
+
+ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
+                                                   Output<ngraph::Node> gamma,
+                                                   Output<ngraph::Node> beta,
+                                                   Output<ngraph::Node> mean,
+                                                   Output<ngraph::Node> variance,
                                                    double epsilon)
-    : Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
+    : Op({gamma, beta, input, mean, variance})
     , m_epsilon(epsilon)
 {
     constructor_validate_and_infer_types();

@@ -125,12 +129,12 @@ ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node>
 // DEPRECATED
 ngraph::op::BatchNormInference::BatchNormInference(double eps,
-                                                   std::shared_ptr<ngraph::Node> gamma,
-                                                   std::shared_ptr<ngraph::Node> beta,
-                                                   std::shared_ptr<ngraph::Node> input,
-                                                   std::shared_ptr<ngraph::Node> mean,
-                                                   std::shared_ptr<ngraph::Node> variance)
-    : Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
+                                                   Output<ngraph::Node> gamma,
+                                                   Output<ngraph::Node> beta,
+                                                   Output<ngraph::Node> input,
+                                                   Output<ngraph::Node> mean,
+                                                   Output<ngraph::Node> variance)
+    : Op({gamma, beta, input, mean, variance})
     , m_epsilon(eps)
 {
     constructor_validate_and_infer_types();

@@ -167,16 +171,16 @@ std::shared_ptr<ngraph::Node>
                                                 new_args.at(2),
                                                 new_args.at(0),
                                                 new_args.at(1),
                                                 new_args.at(3),
                                                 new_args.at(4),
                                                 m_epsilon);
 }
 
-ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
-    std::shared_ptr<ngraph::Node> input,
-    std::shared_ptr<ngraph::Node> gamma,
-    std::shared_ptr<ngraph::Node> beta,
-    std::shared_ptr<ngraph::Node> mean,
-    std::shared_ptr<ngraph::Node> variance,
-    std::shared_ptr<ngraph::Node> delta,
-    double epsilon)
-    : Op("BatchNormTrainingBackprop",
-         check_single_output_args({gamma, beta, input, mean, variance, delta}))
+const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
+
+ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input,
+                                                                 Output<ngraph::Node> gamma,
+                                                                 Output<ngraph::Node> beta,
+                                                                 Output<ngraph::Node> mean,
+                                                                 Output<ngraph::Node> variance,
+                                                                 Output<ngraph::Node> delta,
+                                                                 double epsilon)
+    : Op({gamma, beta, input, mean, variance, delta})
     , m_epsilon(epsilon)
 {

@@ -184,16 +188,14 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
     constructor_validate_and_infer_types();
 }
 
-ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
-    double epsilon,
-    std::shared_ptr<ngraph::Node> gamma,
-    std::shared_ptr<ngraph::Node> beta,
-    std::shared_ptr<ngraph::Node> input,
-    std::shared_ptr<ngraph::Node> mean,
-    std::shared_ptr<ngraph::Node> variance,
-    std::shared_ptr<ngraph::Node> delta)
-    : Op("BatchNormTrainingBackprop",
-         check_single_output_args({gamma, beta, input, mean, variance, delta}))
+ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
+                                                                 Output<ngraph::Node> gamma,
+                                                                 Output<ngraph::Node> beta,
+                                                                 Output<ngraph::Node> input,
+                                                                 Output<ngraph::Node> mean,
+                                                                 Output<ngraph::Node> variance,
+                                                                 Output<ngraph::Node> delta)
+    : Op({gamma, beta, input, mean, variance, delta})
     , m_epsilon(epsilon)
 {
src/ngraph/op/batch_norm.hpp

@@ -31,13 +31,17 @@ namespace ngraph
        class BatchNormTraining : public Op
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           BatchNormTraining() = default;
            /// \param input Must have rank >= 2, [., C, ...]
            /// \param gamma gamma scaling for normalized value. [C]
            /// \param beta bias added to the scaled normalized value [C]
            /// \param epsilon Avoids divsion by 0 if input has 0 variance
-           BatchNormTraining(std::shared_ptr<Node> input,
-                             std::shared_ptr<Node> gamma,
-                             std::shared_ptr<Node> beta,
+           BatchNormTraining(Output<Node> input,
+                             Output<Node> gamma,
+                             Output<Node> beta,
                              double epsilon);
 
            NGRAPH_DEPRECATED_DOC

@@ -62,13 +66,14 @@ namespace ngraph
            /// output[2]: shall have rank 1, with the same span as input's channel axis.
            NGRAPH_DEPRECATED("Use another constructor")
            BatchNormTraining(double eps,
-                             std::shared_ptr<Node> gamma,
-                             std::shared_ptr<Node> beta,
-                             std::shared_ptr<Node> input);
+                             Output<Node> gamma,
+                             Output<Node> beta,
+                             Output<Node> input);
 
            void validate_and_infer_types() override;
 
            double get_eps_value() const { return m_epsilon; }
+           void set_eps_value(double epsilon) { m_epsilon = epsilon; }
 
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

@@ -87,17 +92,20 @@ namespace ngraph
        class BatchNormInference : public Op
        {
        public:
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           BatchNormInference() = default;
            /// \param input [., C, ...]
            /// \param gamma gamma scaling for normalized value. [C]
            /// \param beta bias added to the scaled normalized value [C]
            /// \param mean value for mean normalization [C]
            /// \param variance value for variance normalization [C]
            /// \param epsilon Avoids divsion by 0 if input has 0 variance
-           BatchNormInference(std::shared_ptr<ngraph::Node> input,
-                              std::shared_ptr<ngraph::Node> gamma,
-                              std::shared_ptr<ngraph::Node> beta,
-                              std::shared_ptr<ngraph::Node> mean,
-                              std::shared_ptr<ngraph::Node> variance,
+           BatchNormInference(Output<ngraph::Node> input,
+                              Output<ngraph::Node> gamma,
+                              Output<ngraph::Node> beta,
+                              Output<ngraph::Node> mean,
+                              Output<ngraph::Node> variance,
                               double epsilon);
 
            NGRAPH_DEPRECATED_DOC

@@ -120,15 +128,16 @@ namespace ngraph
            /// output: shall have the same shape as 'input'.
            NGRAPH_DEPRECATED("Use another constructor")
            BatchNormInference(double eps,
-                              std::shared_ptr<ngraph::Node> gamma,
-                              std::shared_ptr<ngraph::Node> beta,
-                              std::shared_ptr<ngraph::Node> input,
-                              std::shared_ptr<ngraph::Node> mean,
-                              std::shared_ptr<ngraph::Node> variance);
+                              Output<ngraph::Node> gamma,
+                              Output<ngraph::Node> beta,
+                              Output<ngraph::Node> input,
+                              Output<ngraph::Node> mean,
+                              Output<ngraph::Node> variance);
 
            void validate_and_infer_types() override;
 
            double get_eps_value() const { return m_epsilon; }
+           void set_eps_value(double epsilon) { m_epsilon = epsilon; }
 
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

@@ -152,28 +161,33 @@ namespace ngraph
        class BatchNormTrainingBackprop : public Op
        {
        public:
-           BatchNormTrainingBackprop(std::shared_ptr<Node> input,
-                                     std::shared_ptr<Node> gamma,
-                                     std::shared_ptr<Node> beta,
-                                     std::shared_ptr<Node> mean,
-                                     std::shared_ptr<Node> variance,
-                                     std::shared_ptr<Node> delta,
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           BatchNormTrainingBackprop() = default;
+           BatchNormTrainingBackprop(Output<Node> input,
+                                     Output<Node> gamma,
+                                     Output<Node> beta,
+                                     Output<Node> mean,
+                                     Output<Node> variance,
+                                     Output<Node> delta,
                                      double epsilon);
 
            NGRAPH_DEPRECATED_DOC
            NGRAPH_DEPRECATED("Use another constructor")
            BatchNormTrainingBackprop(double epsilon,
-                                     std::shared_ptr<Node> gamma,
-                                     std::shared_ptr<Node> beta,
-                                     std::shared_ptr<Node> input,
-                                     std::shared_ptr<Node> mean,
-                                     std::shared_ptr<Node> variance,
-                                     std::shared_ptr<Node> delta);
+                                     Output<Node> gamma,
+                                     Output<Node> beta,
+                                     Output<Node> input,
+                                     Output<Node> mean,
+                                     Output<Node> variance,
+                                     Output<Node> delta);
 
            void validate_and_infer_types() override;
 
            double get_eps_value() const { return m_epsilon; }
+           void set_eps_value(double epsilon) { m_epsilon = epsilon; }
 
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
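The batch-norm headers now take Output<Node> arguments and expose type_name/description, default constructors, and set_eps_value. A hedged sketch of constructing a BatchNormTraining node with the new signature, following the \param comments above (op::Parameter and the include paths are assumptions):

    #include <memory>
    #include "ngraph/op/batch_norm.hpp"
    #include "ngraph/op/parameter.hpp" // assumed include path

    void batch_norm_sketch()
    {
        using namespace ngraph;
        const size_t C = 3;
        auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, C, 4, 4}); // rank >= 2, [N, C, ...]
        auto gamma = std::make_shared<op::Parameter>(element::f32, Shape{C});          // scaling, [C]
        auto beta  = std::make_shared<op::Parameter>(element::f32, Shape{C});          // bias, [C]

        // New signature: (Output<Node> input, Output<Node> gamma, Output<Node> beta, double epsilon).
        auto bn = std::make_shared<op::BatchNormTraining>(
            Output<Node>(input, 0), Output<Node>(gamma, 0), Output<Node>(beta, 0), 1e-5);

        // epsilon is now adjustable after construction.
        bn->set_eps_value(1e-3);
    }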
src/ngraph/op/broadcast.cpp

@@ -20,21 +20,20 @@
 using namespace std;
 using namespace ngraph;
 
-op::Broadcast::Broadcast(const std::string& name,
-                         const NodeVector& args,
+const string op::Broadcast::type_name{"Broadcast"};
+
+op::Broadcast::Broadcast(const OutputVector& args,
                          const Shape& shape,
                          const AxisSet& broadcast_axes)
-    : Op(name, check_single_output_args(args))
+    : Op(args)
     , m_shape(shape)
     , m_broadcast_axes(broadcast_axes)
 {
     constructor_validate_and_infer_types();
 }
 
-op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
-                         const Shape& shape,
-                         const AxisSet& broadcast_axes)
-    : Broadcast("Broadcast", {arg}, shape, broadcast_axes)
+op::Broadcast::Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes)
+    : Broadcast(OutputVector{arg}, shape, broadcast_axes)
 {
 }

@@ -96,10 +95,12 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
     adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes));
 }
 
-op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg,
-                                 const std::shared_ptr<Node>& like_arg,
+const string op::BroadcastLike::type_name{"BroadcastLike"};
+
+op::BroadcastLike::BroadcastLike(const Output<Node>& arg,
+                                 const Output<Node>& like_arg,
                                  const AxisSet& initial_broadcast_axes)
-    : Broadcast("BroadcastLike", {arg, like_arg}, {}, {})
+    : Broadcast({arg, like_arg}, {}, {})
     , m_initial_broadcast_axes(initial_broadcast_axes)
 {
     constructor_validate_and_infer_types();
src/ngraph/op/broadcast.hpp

@@ -27,15 +27,18 @@ namespace ngraph
        class Broadcast : public Op
        {
        public:
-           /// \brief Constructs a conversion operation.
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a broadcast operation.
+           Broadcast() = default;
            /// \brief Constructs a broadcast operation.
            ///
            /// \param arg Node that produces the input tensor to be broadcast.
            /// \param shape The shape of the output tensor.
            /// \param broadcast_axes The axis positions (0-based) in the result that are being broadcast. The
            ///  remaining axes in shape must be the same as the shape of arg.
-           Broadcast(const std::shared_ptr<Node>& arg,
-                     const Shape& shape,
-                     const AxisSet& broadcast_axes);
+           Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes);
 
            void validate_and_infer_types() override;

@@ -44,12 +47,14 @@ namespace ngraph
            /// \return A set containing the indices of the broadcast axes (0-based).
            const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
+           void set_broadcast_axes(const AxisSet& broadcast_axes) { m_broadcast_axes = broadcast_axes; }
            const Shape& get_broadcast_shape() const { return m_shape; }
+           void set_broadcast_shape(const Shape& shape) { m_shape = shape; }
        protected:
-           Broadcast(const std::string& node_type,
-                     const NodeVector& args,
-                     const Shape& shape,
-                     const AxisSet& broadcast_axes);
+           Broadcast(const OutputVector& args, const Shape& shape, const AxisSet& broadcast_axes);
 
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;

@@ -63,6 +68,11 @@ namespace ngraph
        class BroadcastLike : public Broadcast
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Broadcast arg to the same shape as like_arg.
+           BroadcastLike() = default;
            /// \brief Broadcast arg to the same shape as like_arg.
            ///
            /// Once the shape of like_arg is known, this op will be replaced with an equivalent

@@ -72,8 +82,8 @@ namespace ngraph
            /// \param like_arg Provides the shape for the result.
            /// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
            /// arg must be scalar and all axes are broadcast.
-           BroadcastLike(const std::shared_ptr<Node>& arg,
-                         const std::shared_ptr<Node>& like_arg,
+           BroadcastLike(const Output<Node>& arg,
+                         const Output<Node>& like_arg,
                          const AxisSet& initial_broadcast_axes);
 
            virtual std::shared_ptr<Node>

@@ -81,6 +91,11 @@ namespace ngraph
            void infer_shape() override;
 
            const AxisSet& get_initial_broadcast_axes() const { return m_initial_broadcast_axes; }
+           void set_initial_broadcast_axes(const AxisSet& initial_broadcast_axes)
+           {
+               m_initial_broadcast_axes = initial_broadcast_axes;
+           }
 
        protected:
            AxisSet m_initial_broadcast_axes;
        };
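Broadcast now takes an Output<Node>/OutputVector and gains setters for its shape and axes. A hedged sketch of the common case described by the \param comments above, broadcasting a vector along a new leading axis (op::Parameter and include paths are assumptions):

    #include <memory>
    #include "ngraph/op/broadcast.hpp"
    #include "ngraph/op/parameter.hpp" // assumed include path

    void broadcast_sketch()
    {
        using namespace ngraph;
        auto arg = std::make_shared<op::Parameter>(element::f32, Shape{4});

        // Result shape {3, 4}; axis 0 is the broadcast axis, and the remaining axis
        // matches the shape of arg, as the \param comment requires.
        auto bc = std::make_shared<op::Broadcast>(Output<Node>(arg, 0), Shape{3, 4}, AxisSet{0});

        // Accessors/mutators from this header.
        const Shape& s = bc->get_broadcast_shape();
        bc->set_broadcast_axes(AxisSet{0});
        (void)s;
    }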
src/ngraph/op/broadcast_distributed.cpp

@@ -19,8 +19,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::BroadcastDistributed::BroadcastDistributed(const shared_ptr<Node>& arg, int root_id)
-    : Op("BroadcastDistributed", check_single_output_args({arg}))
+const string op::BroadcastDistributed::type_name{"BroadcastDistributed"};
+
+op::BroadcastDistributed::BroadcastDistributed(const Output<Node>& arg, int root_id)
+    : Op({arg})
     , m_root_id(root_id)
 {
     constructor_validate_and_infer_types();

@@ -49,3 +51,8 @@ int op::BroadcastDistributed::get_root_id() const
 {
     return m_root_id;
 }
+
+void op::BroadcastDistributed::set_root_id(int root_id)
+{
+    m_root_id = root_id;
+}
src/ngraph/op/broadcast_distributed.hpp

@@ -27,16 +27,21 @@ namespace ngraph
        class BroadcastDistributed : public Op
        {
        public:
-           BroadcastDistributed(const std::shared_ptr<Node>& arg, int root_id = 0);
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           BroadcastDistributed() = default;
+           BroadcastDistributed(const Output<Node>& arg, int root_id = 0);
            void validate_and_infer_types() override;
 
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
            int get_root_id() const;
+           void set_root_id(int root_id);
 
        private:
-           const int m_root_id;
+           int m_root_id;
        };
    }
 }
src/ngraph/op/ceiling.cpp

@@ -19,8 +19,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::Ceiling::Ceiling(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Ceiling", arg)
+const string op::Ceiling::type_name{"Ceiling"};
+
+op::Ceiling::Ceiling(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
 {
     constructor_validate_and_infer_types();
 }
src/ngraph/op/ceiling.hpp

@@ -26,10 +26,15 @@ namespace ngraph
        class Ceiling : public util::UnaryElementwiseArithmetic
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a ceiling operation.
+           Ceiling() = default;
            /// \brief Constructs a ceiling operation.
            ///
            /// \param arg Node that produces the input tensor.
-           Ceiling(const std::shared_ptr<Node>& arg);
+           Ceiling(const Output<Node>& arg);
 
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
src/ngraph/op/concat.cpp

@@ -22,13 +22,20 @@
 using namespace std;
 using namespace ngraph;
 
-op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
-    : Op("Concat", check_single_output_args(args))
+const string op::Concat::type_name{"Concat"};
+
+op::Concat::Concat(const OutputVector& args, size_t concatenation_axis)
+    : Op(args)
     , m_concatenation_axis(concatenation_axis)
 {
     constructor_validate_and_infer_types();
 }
 
+op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
+    : Concat(as_output_vector(args), concatenation_axis)
+{
+}
+
 void op::Concat::validate_and_infer_types()
 {
     NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required.");
src/ngraph/op/concat.hpp

@@ -28,6 +28,17 @@ namespace ngraph
        class Concat : public Op
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a concatenation operation.
+           Concat() = default;
+           /// \brief Constructs a concatenation operation.
+           ///
+           /// \param args The outputs producing the input tensors.
+           /// \param concatenation_axis The axis along which to concatenate the input tensors.
+           Concat(const OutputVector& args, size_t concatenation_axis);
            /// \brief Constructs a concatenation operation.
            ///
            /// \param args The nodes producing the input tensors.

@@ -41,10 +52,15 @@ namespace ngraph
            /// \return The concatenation axis.
            size_t get_concatenation_axis() const { return m_concatenation_axis; }
+           void set_concatenation_axis(size_t concatenation_axis)
+           {
+               m_concatenation_axis = concatenation_axis;
+           }
        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;
 
-           const size_t m_concatenation_axis;
+           size_t m_concatenation_axis;
        };
    }
 }
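Concat gains an OutputVector constructor, with the older NodeVector form delegating through as_output_vector. A hedged sketch of both forms (op::Parameter and include paths are assumptions):

    #include <memory>
    #include "ngraph/op/concat.hpp"
    #include "ngraph/op/parameter.hpp" // assumed include path

    void concat_sketch()
    {
        using namespace ngraph;
        auto a = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
        auto b = std::make_shared<op::Parameter>(element::f32, Shape{2, 5});

        // Preferred form after this commit: outputs plus the concatenation axis.
        OutputVector outputs{Output<Node>(a, 0), Output<Node>(b, 0)};
        auto c1 = std::make_shared<op::Concat>(outputs, 1); // result shape {2, 8}

        // The NodeVector overload is kept and simply forwards via as_output_vector().
        auto c2 = std::make_shared<op::Concat>(NodeVector{a, b}, 1);
        (void)c1;
        (void)c2;
    }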
src/ngraph/op/constant.cpp

@@ -45,6 +45,8 @@ string to_cpp_string(T value)
     return rc;
 }
 
+const string op::Constant::type_name{"Constant"};
+
 op::Constant::~Constant()
 {
 }
src/ngraph/op/constant.hpp

@@ -34,6 +34,9 @@ namespace ngraph
        class Constant : public Node
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
            /// \brief Constructs a tensor constant.
            ///
            /// \param type The element type of the tensor constant.

@@ -78,7 +81,7 @@ namespace ngraph
            /// \param shape The shape of the tensor constant.
            /// \param values A list of string values to use as the constant data.
            Constant(const element::Type& type, Shape shape, const std::vector<std::string>& values)
-               : Node("Constant", {})
+               : Node({})
                , m_element_type(type)
                , m_shape(shape)
                , m_data(new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(),

@@ -135,7 +138,7 @@ namespace ngraph
            /// \param shape The shape of the tensor constant.
            /// \param data A void* to constant data.
            Constant(const element::Type& type, const Shape& shape, const void* data)
-               : Node("Constant", {})
+               : Node({})
                , m_element_type(type)
                , m_shape(shape)
                , m_data(nullptr)
src/ngraph/op/convert.cpp

@@ -21,8 +21,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::Convert::Convert(const shared_ptr<Node>& arg, const element::Type& element_type)
-    : Op("Convert", check_single_output_args({arg}))
+const string op::Convert::type_name{"Convert"};
+
+op::Convert::Convert(const Output<Node>& arg, const element::Type& element_type)
+    : Op({arg})
     , m_element_type(element_type)
 {
     constructor_validate_and_infer_types();
src/ngraph/op/convert.hpp

@@ -26,11 +26,16 @@ namespace ngraph
        class Convert : public Op
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a conversion operation.
+           Convert() = default;
            /// \brief Constructs a conversion operation.
            ///
            /// \param arg Node that produces the input tensor.
            /// \param element_type Element type for the output tensor.
-           Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
+           Convert(const Output<Node>& arg, const ngraph::element::Type& element_type);
 
            void validate_and_infer_types() override;

@@ -38,8 +43,13 @@ namespace ngraph
                copy_with_new_args(const NodeVector& new_args) const override;
 
            const element::Type& get_convert_element_type() const { return m_element_type; }
+           void set_convert_element_type(const element::Type& element_type)
+           {
+               m_element_type = element_type;
+           }
 
        protected:
-           const ngraph::element::Type m_element_type;
+           ngraph::element::Type m_element_type;
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;
        };
src/ngraph/op/convolution.cpp

@@ -27,15 +27,17 @@
 using namespace std;
 using namespace ngraph;
 
-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+const string op::Convolution::type_name{"Convolution"};
+
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                              const Strides& window_movement_strides,
                              const Strides& window_dilation_strides,
                              const CoordinateDiff& padding_below,
                              const CoordinateDiff& padding_above,
                              const Strides& data_dilation_strides,
                              const PadType& pad_type)
-    : Op("Convolution", check_single_output_args({data_batch, filters}))
+    : Op({data_batch, filters})
     , m_window_movement_strides(window_movement_strides)
     , m_window_dilation_strides(window_dilation_strides)
     , m_padding_below(padding_below)

@@ -114,8 +116,8 @@ void op::Convolution::validate_and_infer_types()
     set_output_type(0, result_et, result_shape);
 }
 
-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                              const Strides& window_movement_strides,
                              const Strides& window_dilation_strides,
                              const CoordinateDiff& padding_below,

@@ -130,8 +132,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
 {
 }
 
-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                              const Strides& window_movement_strides,
                              const Strides& window_dilation_strides)
     : Convolution(data_batch,

@@ -143,8 +145,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
 {
 }
 
-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                              const Strides& window_movement_strides)
     : Convolution(data_batch,
                   filters,

@@ -155,7 +157,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
 {
 }
 
-op::Convolution::Convolution(const shared_ptr<Node>& data_batch, const shared_ptr<Node>& filters)
+op::Convolution::Convolution(const Output<Node>& data_batch, const Output<Node>& filters)
     : Convolution(data_batch,
                   filters,
                   Strides(),
                   Strides(),
                   CoordinateDiff(),
                   CoordinateDiff())
 {
 }

@@ -204,15 +206,17 @@ void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const Node
                                         m_data_dilation_strides));
 }
 
+const string op::ConvolutionBackpropData::type_name{"ConvolutionBackpropData"};
+
 op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_shape,
-                                                     const shared_ptr<Node>& filters,
-                                                     const shared_ptr<Node>& output_delta,
+                                                     const Output<Node>& filters,
+                                                     const Output<Node>& output_delta,
                                                      const Strides& window_movement_strides_forward,
                                                      const Strides& window_dilation_strides_forward,
                                                      const CoordinateDiff& padding_below_forward,
                                                      const CoordinateDiff& padding_above_forward,
                                                      const Strides& data_dilation_strides_forward)
-    : Op("ConvolutionBackpropData", check_single_output_args({filters, output_delta}))
+    : Op({filters, output_delta})
     , m_data_batch_shape(data_batch_shape)
     , m_window_movement_strides_forward(window_movement_strides_forward)
     , m_window_dilation_strides_forward(window_dilation_strides_forward)

@@ -332,14 +336,14 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
             m_data_dilation_strides_forward[i]);
     }
 
-    auto swap_NC = [](const shared_ptr<Node> n) {
-        AxisVector ax_order = ngraph::get_default_order(n->get_shape());
+    auto swap_NC = [](const Output<Node>& n) {
+        AxisVector ax_order = ngraph::get_default_order(n.get_shape());
         ax_order[0] = 1;
         ax_order[1] = 0;
 
-        auto new_shape = n->get_shape();
-        new_shape[0] = n->get_shape()[1];
-        new_shape[1] = n->get_shape()[0];
+        auto new_shape = n.get_shape();
+        new_shape[0] = n.get_shape()[1];
+        new_shape[1] = n.get_shape()[0];
 
         return make_shared<op::Reshape>(n, ax_order, new_shape);
     };

@@ -422,16 +426,18 @@ CoordinateDiff op::ConvolutionBackpropData::compute_backward_delta_out_pad_above
     return backward_delta_out_pad_above;
 }
 
+const string op::ConvolutionBackpropFilters::type_name{"ConvolutionBackpropFilters"};
+
 op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
-    const shared_ptr<Node>& data_batch,
+    const Output<Node>& data_batch,
     const Shape& filters_shape,
-    const shared_ptr<Node>& output_delta,
+    const Output<Node>& output_delta,
     const Strides& window_movement_strides_forward,
     const Strides& window_dilation_strides_forward,
     const CoordinateDiff& padding_below_forward,
     const CoordinateDiff& padding_above_forward,
     const Strides& data_dilation_strides_forward)
-    : Op("ConvolutionBackpropFilters", check_single_output_args({data_batch, output_delta}))
+    : Op({data_batch, output_delta})
     , m_filters_shape(filters_shape)
     , m_window_movement_strides_forward(window_movement_strides_forward)
     , m_window_dilation_strides_forward(window_dilation_strides_forward)
src/ngraph/op/convolution.hpp

This diff is collapsed.
src/ngraph/op/cos.cpp

@@ -22,8 +22,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::Cos::Cos(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Cos", arg)
+const string op::Cos::type_name{"Cos"};
+
+op::Cos::Cos(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
 {
     constructor_validate_and_infer_types();
 }
src/ngraph/op/cos.hpp

@@ -26,10 +26,15 @@ namespace ngraph
        class Cos : public util::UnaryElementwiseArithmetic
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a cosine operation.
+           Cos() = default;
            /// \brief Constructs a cosine operation.
            ///
            /// \param arg Node that produces the input tensor.
-           Cos(const std::shared_ptr<Node>& arg);
+           Cos(const Output<Node>& arg);
 
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
src/ngraph/op/cosh.cpp

@@ -21,8 +21,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::Cosh::Cosh(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Cosh", arg)
+const string op::Cosh::type_name{"Cosh"};
+
+op::Cosh::Cosh(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
 {
     constructor_validate_and_infer_types();
 }
src/ngraph/op/cosh.hpp

@@ -26,10 +26,15 @@ namespace ngraph
        class Cosh : public util::UnaryElementwiseArithmetic
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a hyperbolic cosine operation.
+           Cosh() = default;
            /// \brief Constructs a hyperbolic cosine operation.
            ///
            /// \param arg Node that produces the input tensor.
-           Cosh(const std::shared_ptr<Node>& arg);
+           Cosh(const Output<Node>& arg);
 
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
src/ngraph/op/dequantize.cpp

@@ -20,13 +20,15 @@
 using namespace std;
 using namespace ngraph;
 
-op::Dequantize::Dequantize(const shared_ptr<Node>& input,
-                           const shared_ptr<Node>& scale,
-                           const shared_ptr<Node>& zero_point,
+const string op::Dequantize::type_name{"Dequantize"};
+
+op::Dequantize::Dequantize(const Output<Node>& input,
+                           const Output<Node>& scale,
+                           const Output<Node>& zero_point,
                            const element::Type& type,
                            const AxisSet& axes)
-    : Op("Dequantize", check_single_output_args({input, scale, zero_point}))
+    : Op({input, scale, zero_point})
     , m_type(type)
     , m_axes(axes)
 {
src/ngraph/op/dequantize.hpp

@@ -30,31 +30,40 @@ namespace ngraph
        class Dequantize : public ngraph::op::Op
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a Dequantize operation
+           Dequantize() = default;
 
            /// \brief Constructs a Dequantize operation
            /// \param input quantized input
            /// \param scale scale used for mapping
            /// \param zero_point zero point used for mapping
            /// \param type output element type
            /// \param axes axis positions on which `scale` and `zero_point` are specified
-           Dequantize(const std::shared_ptr<Node>& input,
-                      const std::shared_ptr<Node>& scale,
-                      const std::shared_ptr<Node>& zero_point,
-                      const ngraph::element::Type& type,
-                      const ngraph::AxisSet& axes);
+           Dequantize(const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const element::Type& type,
+                      const AxisSet& axes);
 
            void validate_and_infer_types() override;
 
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
 
-           const ngraph::AxisSet& get_axes() const { return m_axes; }
+           const AxisSet& get_axes() const { return m_axes; }
+           void set_axes(const AxisSet& axes) { m_axes = axes; }
 
            const element::Type& get_type() const { return m_type; }
+           void set_type(const element::Type& type) { m_type = type; }
 
        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;
 
        private:
-           ngraph::element::Type m_type;
-           ngraph::AxisSet m_axes;
+           element::Type m_type;
+           AxisSet m_axes;
        };
    }
 }
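Dequantize keeps its (input, scale, zero_point, type, axes) meaning but now takes Output<Node> arguments and exposes setters. A hedged sketch of a per-channel dequantize along axis 1, matching the \param comments above (op::Parameter, the element::u8/f32 choices, and the include paths are assumptions):

    #include <memory>
    #include "ngraph/op/dequantize.hpp"
    #include "ngraph/op/parameter.hpp" // assumed include path

    void dequantize_sketch()
    {
        using namespace ngraph;
        auto input      = std::make_shared<op::Parameter>(element::u8, Shape{2, 3}); // quantized input
        auto scale      = std::make_shared<op::Parameter>(element::f32, Shape{3});   // per-channel scale
        auto zero_point = std::make_shared<op::Parameter>(element::u8, Shape{3});    // per-channel zero point

        auto dq = std::make_shared<op::Dequantize>(Output<Node>(input, 0),
                                                   Output<Node>(scale, 0),
                                                   Output<Node>(zero_point, 0),
                                                   element::f32, // output element type
                                                   AxisSet{1});  // axes carrying scale/zero_point

        // The element type and axes are now mutable after construction.
        dq->set_type(element::f32);
        dq->set_axes(AxisSet{1});
    }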
src/ngraph/op/divide.cpp

@@ -21,20 +21,21 @@
 using namespace std;
 using namespace ngraph;
 
-op::Divide::Divide(const shared_ptr<Node>& arg0,
-                   const shared_ptr<Node>& arg1,
+const string op::Divide::type_name{"Divide"};
+
+op::Divide::Divide(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                    const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Divide", arg0, arg1, autob)
-    , m_pythondiv(true)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
 
-op::Divide::Divide(const shared_ptr<Node>& arg0,
-                   const shared_ptr<Node>& arg1,
+op::Divide::Divide(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                    bool pythondiv,
                    const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Divide", arg0, arg1, autob)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
     , m_pythondiv(pythondiv)
 {
     constructor_validate_and_infer_types();

@@ -63,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
     adjoints.add_delta(y, -delta * shared_from_this() / y);
 }
 
-shared_ptr<Node> ngraph::operator/(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
+shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1)
 {
     return make_shared<op::Divide>(arg0, arg1);
 }
src/ngraph/op/divide.hpp
@@ -26,14 +26,19 @@ namespace ngraph
        class Divide : public util::BinaryElementwiseArithmetic
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a division operation.
+            Divide() = default;
            /// \brief Constructs a division operation.
            ///
            /// \param arg0 Node that produces the first input tensor.
            /// \param arg1 Node that produces the second input tensor.
            /// \param pythondiv Use Python style rounding for integral type
            /// \param autob Auto broadcast specification
-            Divide(const std::shared_ptr<Node>& arg0,
-                   const std::shared_ptr<Node>& arg1,
+            Divide(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                   bool pythondiv,
                   const AutoBroadcastSpec& autob = AutoBroadcastSpec());
@@ -42,11 +47,12 @@ namespace ngraph
            /// \param arg0 Node that produces the first input tensor.
            /// \param arg1 Node that produces the second input tensor.
            /// \param autob Auto broadcast specification
-            Divide(const std::shared_ptr<Node>& arg0,
-                   const std::shared_ptr<Node>& arg1,
+            Divide(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                   const AutoBroadcastSpec& autob = AutoBroadcastSpec());

            bool is_pythondiv() const { return m_pythondiv; }
+            void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
@@ -54,10 +60,10 @@ namespace ngraph
                                           const NodeVector& deltas) override;

        protected:
-            bool m_pythondiv;
+            bool m_pythondiv{true};
        };
    }

-    std::shared_ptr<ngraph::Node> operator/(const std::shared_ptr<ngraph::Node> arg0,
-                                            const std::shared_ptr<ngraph::Node> arg1);
+    std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0,
+                                            const Output<ngraph::Node> arg1);
}
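As a quick illustration (not part of this commit), the following sketch builds both forms of Divide declared above and uses the operator/ overload, which simply forwards to make_shared<op::Divide>(arg0, arg1). Shapes and element types are assumed for the example.

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    // Sketch: the two Divide constructors plus the ngraph::operator/ overload.
    std::shared_ptr<Function> make_divide_example()
    {
        auto x = std::make_shared<op::Parameter>(element::f32, Shape{4});
        auto y = std::make_shared<op::Parameter>(element::f32, Shape{4});

        auto d0 = std::make_shared<op::Divide>(x, y);        // m_pythondiv defaults to true
        auto d1 = std::make_shared<op::Divide>(x, y, false);  // disable Python-style integer rounding
        auto d2 = x / y;                                      // uses the operator/ declared above

        return std::make_shared<Function>(NodeVector{d0, d1, d2}, ParameterVector{x, y});
    }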
src/ngraph/op/dot.cpp
@@ -29,16 +29,18 @@
using namespace std;
using namespace ngraph;

-op::Dot::Dot(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
+const string op::Dot::type_name{"Dot"};
+
+op::Dot::Dot(const Output<Node>& arg0, const Output<Node>& arg1)
    : Dot(arg0, arg1, 0, false)
{
}

-op::Dot::Dot(const shared_ptr<Node>& arg0,
-             const shared_ptr<Node>& arg1,
+op::Dot::Dot(const Output<Node>& arg0,
+             const Output<Node>& arg1,
             size_t reduction_axes_count,
             bool has_reduction_axes_count)
-    : Op("Dot", check_single_output_args({arg0, arg1}))
+    : Op({arg0, arg1})
    , m_reduction_axes_count(reduction_axes_count)
    , m_has_reduction_axes_count(has_reduction_axes_count)
{
@@ -154,7 +156,7 @@ void op::Dot::validate_and_infer_types()
    set_output_type(0, result_et, result_shape);
}

-shared_ptr<op::Reshape> make_reshape_axes_to_front(const shared_ptr<Node>& n,
+shared_ptr<op::Reshape> make_reshape_axes_to_front(const Output<Node>& n,
                                                    const Shape& front_shape,
                                                    const Shape& back_shape)
{
src/ngraph/op/dot.hpp
@@ -28,13 +28,18 @@ namespace ngraph
        class Dot : public Op
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a dot product operation.
+            Dot() = default;
            /// \brief Constructs a dot product operation.
            ///
            /// \param arg0 The node producing the first argument.
            /// \param arg1 The node producing the second argument.
            /// \param reduction_axes_count The number of axes to dot.
-            Dot(const std::shared_ptr<Node>& arg0,
-                const std::shared_ptr<Node>& arg1,
+            Dot(const Output<Node>& arg0,
+                const Output<Node>& arg1,
                size_t reduction_axes_count,
                bool has_reduction_axes_count = true);
@@ -48,11 +53,20 @@ namespace ngraph
            ///
            /// \param arg0 The node producing the first argument.
            /// \param arg1 The node producing the second argument.
-            Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
+            Dot(const Output<Node>& arg0, const Output<Node>& arg1);

            void validate_and_infer_types() override;

            size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
+            void get_reduction_axes_count(size_t reduction_axes_count)
+            {
+                m_reduction_axes_count = reduction_axes_count;
+            }
+            bool get_has_reduction_axes_count() const { return m_has_reduction_axes_count; }
+            void set_has_reduction_axes_count(bool has_reduction_axes_count)
+            {
+                m_has_reduction_axes_count = has_reduction_axes_count;
+            }
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override
            {
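For context (not part of the diff), a minimal sketch of the reduction_axes_count parameter documented above: with a count of 1 the last axis of arg0 is dotted against the first axis of arg1, i.e. a plain matrix multiply. The shapes are assumed for illustration.

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    // Sketch: [2,3] . [3,4] -> [2,4] via op::Dot.
    std::shared_ptr<Function> make_dot_example()
    {
        auto a = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
        auto b = std::make_shared<op::Parameter>(element::f32, Shape{3, 4});

        auto dot_default = std::make_shared<op::Dot>(a, b);     // reduction count chosen during validation
        auto dot_explicit = std::make_shared<op::Dot>(a, b, 1); // reduce exactly one axis

        return std::make_shared<Function>(NodeVector{dot_default, dot_explicit},
                                          ParameterVector{a, b});
    }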
src/ngraph/op/embedding_lookup.cpp
@@ -19,6 +19,8 @@
using namespace std;
using namespace ngraph;

+const string op::EmbeddingLookup::type_name{"EmbeddingLookup"};
+
void op::EmbeddingLookup::validate_and_infer_types()
{
    element::Type result_et = get_input_element_type(1);
src/ngraph/op/embedding_lookup.hpp
@@ -28,6 +28,11 @@ namespace ngraph
        class EmbeddingLookup : public Op
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a EmbeddingLookup operation.
+            EmbeddingLookup() = default;
            /// \brief Constructs a EmbeddingLookup operation.
            ///
            /// EmbeddingLookup constructs an output tensor by replacing every index in a given input tensor
@@ -36,8 +41,8 @@ namespace ngraph
            /// \param data The input indices for tokens to be translated into embeddings
            /// \param weights is a dense matrix [N,M] where each row 0..N
            /// corresponds to an embedding (i.e. typically, a vector of real numbers) of length M
-            EmbeddingLookup(const std::shared_ptr<Node>& data, const std::shared_ptr<Node>& weights)
-                : Op("EmbeddingLookup", check_single_output_args({data, weights}))
+            EmbeddingLookup(const Output<Node>& data, const Output<Node>& weights)
+                : Op({data, weights})
            {
                constructor_validate_and_infer_types();
            }
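To make the data/weights contract above concrete, here is a small sketch (not from the commit): 8 token indices are looked up in an assumed [N=100, M=16] embedding table, giving an output whose shape is the index shape with the embedding length appended.

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    // Sketch: indices of shape [8] into a [100,16] table -> embeddings of shape [8,16].
    std::shared_ptr<Function> make_embedding_example()
    {
        auto indices = std::make_shared<op::Parameter>(element::i32, Shape{8});
        auto table = std::make_shared<op::Parameter>(element::f32, Shape{100, 16});
        auto emb = std::make_shared<op::EmbeddingLookup>(indices, table);
        return std::make_shared<Function>(NodeVector{emb}, ParameterVector{indices, table});
    }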
src/ngraph/op/equal.cpp
@@ -19,10 +19,10 @@
using namespace std;
using namespace ngraph;

-op::Equal::Equal(const shared_ptr<Node>& arg0,
-                 const shared_ptr<Node>& arg1,
+const string op::Equal::type_name{"Equal"};
+
+op::Equal::Equal(const Output<Node>& arg0,
+                 const Output<Node>& arg1,
                  const AutoBroadcastSpec& autob)
-    : BinaryElementwiseComparison("Equal", arg0, arg1, autob)
+    : BinaryElementwiseComparison(arg0, arg1, autob)
{
    constructor_validate_and_infer_types();
}
src/ngraph/op/equal.hpp
@@ -40,13 +40,18 @@ namespace ngraph
        class Equal : public util::BinaryElementwiseComparison
        {
        public:
-            /// \brief Constructs an is-equal operation.
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs an equal operation.
+            Equal() = default;
+            /// \brief Constructs an equal operation.
            ///
            /// \param arg0 Node that produces the first input tensor.
            /// \param arg1 Node that produces the second input tensor.
            /// \param autob Auto broadcast specification
-            Equal(const std::shared_ptr<Node>& arg0,
-                  const std::shared_ptr<Node>& arg1,
+            Equal(const Output<Node>& arg0,
+                  const Output<Node>& arg1,
                  const AutoBroadcastSpec& autob = AutoBroadcastSpec());

            virtual std::shared_ptr<Node>
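A brief sketch of the autob parameter documented above (not part of the diff): it assumes AutoBroadcastSpec and AutoBroadcastType come from op/util/attr_types.hpp, as with the other binary elementwise ops, and that NumPy-style broadcasting lets a [3] tensor be compared against a [2,3] tensor.

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    // Sketch: elementwise comparison with NumPy-style auto-broadcasting; the
    // result element type is boolean.
    std::shared_ptr<Function> make_equal_example()
    {
        auto a = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
        auto b = std::make_shared<op::Parameter>(element::f32, Shape{3});
        auto eq = std::make_shared<op::Equal>(
            a, b, op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY));
        return std::make_shared<Function>(NodeVector{eq}, ParameterVector{a, b});
    }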
src/ngraph/op/erf.cpp
@@ -21,14 +21,16 @@
using namespace std;
using namespace ngraph;

+const string op::Erf::type_name{"Erf"};
+
shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<Erf>(new_args.at(0));
}

-op::Erf::Erf(shared_ptr<Node> arg)
-    : UnaryElementwiseArithmetic("Erf", arg)
+op::Erf::Erf(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
src/ngraph/op/erf.hpp
@@ -27,7 +27,11 @@ namespace ngraph
        class Erf : public util::UnaryElementwiseArithmetic
        {
        public:
-            Erf(std::shared_ptr<Node> arg);
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            Erf() = default;
+            Erf(const Output<Node>& arg);

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
src/ngraph/op/exp.cpp
@@ -20,8 +20,10 @@
using namespace std;
using namespace ngraph;

-op::Exp::Exp(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Exp", arg)
+const string op::Exp::type_name{"Exp"};
+
+op::Exp::Exp(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
src/ngraph/op/exp.hpp
@@ -26,10 +26,15 @@ namespace ngraph
        class Exp : public util::UnaryElementwiseArithmetic
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs an exponential operation.
+            Exp() = default;
            /// \brief Constructs an exponential operation.
            ///
            /// \param arg Node that produces the input tensor.
-            Exp(const std::shared_ptr<Node>& arg);
+            Exp(const Output<Node>& arg);

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
src/ngraph/op/floor.cpp
@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;

-op::Floor::Floor(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Floor", arg)
+const string op::Floor::type_name{"Floor"};
+
+op::Floor::Floor(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
src/ngraph/op/floor.hpp
@@ -26,10 +26,15 @@ namespace ngraph
        class Floor : public util::UnaryElementwiseArithmetic
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a floor operation.
+            Floor() = default;
            /// \brief Constructs a floor operation.
            ///
            /// \param arg Node that produces the input tensor.
-            Floor(const std::shared_ptr<Node>& arg);
+            Floor(const Output<Node>& arg);

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
src/ngraph/op/fused/split.cpp
@@ -72,6 +72,13 @@ void op::Split::pre_validate_and_infer_types()
                          dimension_at_axis,
                          " has to be equal to the sum of splits passed to the op: ",
                          sum_splits);

+        const bool all_splits_positive =
+            all_of(begin(m_splits), end(m_splits), [](const size_t v) { return v > 0; });
+
+        NODE_VALIDATION_CHECK(this,
+                              all_splits_positive == true,
+                              "All values of the 'splits' attribute must be greater than zero");
    }
}
src/ngraph/op/reshape.cpp
@@ -24,10 +24,12 @@
using namespace std;
using namespace ngraph;

-op::Reshape::Reshape(const shared_ptr<Node>& arg,
+const string op::Reshape::type_name{"Reshape"};
+
+op::Reshape::Reshape(const Output<Node>& arg,
                     const AxisVector& input_order,
                     const Shape& output_shape)
-    : Op("Reshape", check_single_output_args({arg}))
+    : Op({arg})
    , m_input_order(input_order)
    , m_output_shape(output_shape)
{
src/ngraph/op/reshape.hpp
@@ -60,6 +60,11 @@ namespace ngraph
        class Reshape : public Op
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a reshape operation.
+            Reshape() = default;
            /// \brief Constructs a reshape operation.
            ///
            /// \param arg The tensor to be reshaped.
@@ -67,7 +72,7 @@ namespace ngraph
            /// sequence \f$(0,\dots,n-1)\f$ where \f$n\f$ is the rank of the input tensor.
            /// \param output_shape The output shape. If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape must
            /// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
-            Reshape(const std::shared_ptr<Node>& arg,
+            Reshape(const Output<Node>& arg,
                    const AxisVector& input_order,
                    const Shape& output_shape);
@@ -78,15 +83,18 @@ namespace ngraph
            /// \return The order in which to iterate over input axes.
            const AxisVector& get_input_order() const { return m_input_order; }
+            void set_input_order(const AxisVector& input_order) { m_input_order = input_order; }
            /// \return The shape of the output tensor.
            const Shape& get_output_shape() const { return m_output_shape; }
+            void set_output_shape(const Shape& output_shape) { m_output_shape = output_shape; }
            bool get_is_transpose() const { return m_is_transpose; }
+            void set_is_transpose(bool is_transpose) { m_is_transpose = is_transpose; }
        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;

-            const AxisVector m_input_order;
-            const Shape m_output_shape;
+            AxisVector m_input_order;
+            Shape m_output_shape;
            bool m_is_transpose{false};
        };
    }
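As a worked example of the input_order/output_shape contract documented above (not part of the commit): reading a [2,3] tensor with input_order {1,0} and writing it as Shape{3,2} performs a transpose, and the element counts match (2*3 == 3*2) as required.

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    // Sketch: Reshape used as a 2-D transpose.
    std::shared_ptr<Function> make_transpose_example()
    {
        auto x = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
        auto xt = std::make_shared<op::Reshape>(x, AxisVector{1, 0}, Shape{3, 2});
        return std::make_shared<Function>(NodeVector{xt}, ParameterVector{x});
    }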
src/ngraph/op/result.cpp
@@ -24,8 +24,10 @@
using namespace std;
using namespace ngraph;

-op::Result::Result(const shared_ptr<Node>& arg, bool needs_default_layout)
-    : Op("Result", check_single_output_args({arg}))
+const string op::Result::type_name{"Result"};
+
+op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
+    : Op({arg})
    , m_needs_default_layout(needs_default_layout)
{
    constructor_validate_and_infer_types();
src/ngraph/op/result.hpp
@@ -27,10 +27,15 @@ namespace ngraph
        class Result : public Op
        {
        public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Allows a value to be used as a function result.
+            Result() = default;
            /// \brief Allows a value to be used as a function result.
            ///
            /// \param arg Node that produces the input tensor.
-            Result(const std::shared_ptr<Node>& arg, bool needs_default_layout = false);
+            Result(const Output<Node>& arg, bool needs_default_layout = false);

            void validate_and_infer_types() override;
src/ngraph/pattern/matcher.cpp
@@ -298,9 +298,14 @@ namespace ngraph
            if (graph_node->is_commutative())
            {
-                std::sort(begin(pattern_args), end(pattern_args));
                // TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster
+                std::sort(begin(pattern_args),
+                          end(pattern_args),
+                          [](const std::shared_ptr<ngraph::Node>& n1,
+                             const std::shared_ptr<ngraph::Node>& n2) {
+                              return n1->get_instance_id() < n2->get_instance_id();
+                          });
                do
                {
                    NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
@@ -311,7 +316,13 @@ namespace ngraph
                        pattern_map.insert(begin(copy), end(copy));
                        return true;
                    }
-                } while (std::next_permutation(begin(pattern_args), end(pattern_args)));
+                } while (std::next_permutation(
+                    begin(pattern_args),
+                    end(pattern_args),
+                    [](const std::shared_ptr<ngraph::Node>& n1,
+                       const std::shared_ptr<ngraph::Node>& n2) {
+                        return n1->get_instance_id() < n2->get_instance_id();
+                    }));
            }
            else
            {
src/ngraph/runtime/backend.cpp
@@ -90,3 +90,9 @@ std::shared_ptr<runtime::Executable> runtime::Backend::load(istream& input_strea
{
    throw runtime_error("load opertion unimplemented.");
}

+bool runtime::Backend::set_config(const map<string, string>& config, string& error)
+{
+    error = "set_config not supported";
+    return false;
+}
src/ngraph/runtime/backend.hpp
@@ -139,4 +139,13 @@ public:
    /// \param op_name is the name of the backend specific op
    /// \returns a shared pointer to the op if found, else nullptr
    virtual std::shared_ptr<ngraph::Node> get_backend_op(const std::string& op_name, ...);

+    /// \brief Allows sending backend specific configuration. The map contains key, value pairs
+    ///     specific to a particluar backend. The definition of these key, value pairs is
+    ///     defined by each backend.
+    /// \param config The configuration map sent to the backend
+    /// \param error An error string describing any error encountered
+    /// \returns true if the configuration is supported, false otherwise. On false the error
+    ///     parameter value is valid.
+    virtual bool set_config(const std::map<std::string, std::string>& config, std::string& error);
};
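A minimal usage sketch of the new set_config entry point (not from the commit; it mirrors the "test_echo" key that the interpreter backend below recognizes, where the value is echoed back through the error string):

    #include <iostream>
    #include <map>
    #include <string>

    #include "ngraph/runtime/backend.hpp"

    // Sketch: pass a backend-specific configuration map; keys are defined by each backend.
    void configure_backend_example()
    {
        auto backend = ngraph::runtime::Backend::create("INTERPRETER");
        std::string error;
        std::map<std::string, std::string> config = {{"test_echo", "hello"}};
        if (!backend->set_config(config, error))
        {
            // The base Backend implementation returns false with "set_config not supported".
            std::cerr << "backend config rejected: " << error << "\n";
        }
    }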
src/ngraph/runtime/interpreter/int_backend.cpp
@@ -105,3 +105,16 @@ std::shared_ptr<runtime::Executable> runtime::interpreter::INTBackend::load(istr
    }
    return exec;
}

+bool runtime::interpreter::INTBackend::set_config(const map<string, string>& config, string& error)
+{
+    bool rc = false;
+    auto it = config.find("test_echo");
+    error = "";
+    if (it != config.end())
+    {
+        error = it->second;
+        rc = true;
+    }
+    return rc;
+}
src/ngraph/runtime/interpreter/int_backend.hpp
@@ -58,6 +58,8 @@ public:
    bool is_supported(const Node& node) const override;

+    bool set_config(const std::map<std::string, std::string>& config, std::string& error) override;
+
private:
    std::set<std::string> m_unsupported_op_name_list;
};
src/ngraph/serializer.cpp
@@ -141,6 +141,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
+#include "ngraph/provenance.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
@@ -1803,6 +1804,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
        {
            node->set_friendly_name(node_name);
        }
+        if (ngraph::get_provenance_enabled())
+        {
+            std::vector<json> prov_js = node_js.at("provenance_tags");
+            for (auto prov_tag : prov_js)
+            {
+                node->add_provenance_tag(prov_tag);
+            }
+        }
        m_node_map[node_name] = node;
    }
    catch (...)
@@ -1914,6 +1923,15 @@ json JSONSerializer::serialize_node(const Node& n)
        }
        node["output_shapes"] = output_shapes;
    }
+    if (ngraph::get_provenance_enabled())
+    {
+        json provenance_tags = json::array();
+        for (auto prov_tag : n.get_provenance_tags())
+        {
+            provenance_tags.push_back(prov_tag);
+        }
+        node["provenance_tags"] = provenance_tags;
+    }

    string node_op = n.description();
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
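A sketch of round-tripping provenance tags through the JSON serializer (not part of the diff). It assumes provenance.hpp exposes a set_provenance_enabled(bool) setter to pair with the get_provenance_enabled() guard used above; serialize() and deserialize() are the existing serializer.hpp entry points.

    #include "ngraph/ngraph.hpp"
    #include "ngraph/provenance.hpp"
    #include "ngraph/serializer.hpp"

    using namespace ngraph;

    // Sketch: tag a node, serialize, and deserialize with provenance enabled.
    void provenance_roundtrip_example()
    {
        set_provenance_enabled(true); // assumed setter matching get_provenance_enabled()

        auto x = std::make_shared<op::Parameter>(element::f32, Shape{2});
        auto y = std::make_shared<op::Abs>(x);
        y->add_provenance_tag("my_layer/abs");
        auto f = std::make_shared<Function>(NodeVector{y}, ParameterVector{x});

        std::string js = serialize(f); // writes "provenance_tags" per node
        auto f2 = deserialize(js);     // reads the tags back onto the new nodes
    }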
test/backend_api.cpp
@@ -37,6 +37,28 @@ TEST(backend_api, invalid_name)
    ASSERT_ANY_THROW(ngraph::runtime::Backend::create("COMPLETELY-BOGUS-NAME"));
}

+TEST(backend_api, config)
+{
+    auto backend = runtime::Backend::create("INTERPRETER");
+    string error;
+    string message = "hello";
+    map<string, string> config = {{"test_echo", message}};
+    EXPECT_TRUE(backend->set_config(config, error));
+    EXPECT_STREQ(error.c_str(), message.c_str());
+    EXPECT_FALSE(backend->set_config({}, error));
+    EXPECT_STREQ(error.c_str(), "");
+}
+
+TEST(backend_api, config_unsupported)
+{
+    auto backend = runtime::Backend::create("NOP");
+    string error;
+    string message = "hello";
+    map<string, string> config = {{"test_echo", message}};
+    EXPECT_FALSE(backend->set_config(config, error));
+    EXPECT_FALSE(error == "");
+}
+
#ifndef NGRAPH_JSON_DISABLE
TEST(backend_api, save_load)
{
test/distributed.in.cpp
@@ -50,25 +50,25 @@ static void test_allreduce_common(reduction::Type reduce_type)
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-    switch (reduce_type.get_type())
+    switch (reduce_type)
    {
-    case reduction::Type_t::sum:
+    case reduction::Type::SUM:
        copy_data(a, v);
        std::transform(v.begin(), v.end(), v.begin(),
                       std::bind1st(std::multiplies<float>(), comm_size));
        break;
-    case reduction::Type_t::prod:
+    case reduction::Type::PROD:
        copy_data(a, v);
        std::transform(v.begin(), v.end(), v.begin(),
                       [&](float elm) -> float { return pow(elm, comm_size); });
        break;
-    case reduction::Type_t::min:
-    case reduction::Type_t::max:
+    case reduction::Type::MIN:
+    case reduction::Type::MAX:
        auto shift = get_distributed_interface()->get_rank();
        std::rotate(v.begin(), v.begin() + shift % v.size(), v.end());
        copy_data(a, v);
-        if (reduce_type == reduction::Type_t::min)
+        if (reduce_type == reduction::Type::MIN)
        {
            std::fill(v.begin(), v.end(), 1);
            for (int i = 1; i < static_cast<int>(v.size()) - comm_size + 1; i++)
@@ -93,23 +93,23 @@ static void test_allreduce_common(reduction::Type reduce_type)
TEST(distributed_${BACKEND_NAME}, allreduce_sum)
{
-    test_allreduce_common(reduction::sum);
+    test_allreduce_common(reduction::Type::SUM);
}

TEST(distributed_${BACKEND_NAME}, allreduce_min)
{
-    test_allreduce_common(reduction::min);
+    test_allreduce_common(reduction::Type::MIN);
}

TEST(distributed_${BACKEND_NAME}, allreduce_max)
{
-    test_allreduce_common(reduction::max);
+    test_allreduce_common(reduction::Type::MAX);
}

#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
TEST(distributed_${BACKEND_NAME}, allreduce_prod)
{
-    test_allreduce_common(reduction::prod);
+    test_allreduce_common(reduction::Type::PROD);
}
#endif
test/pattern.cpp
@@ -514,6 +514,33 @@ TEST(pattern, previous_matches)
    }
}

+TEST(pattern, test_sort)
+{
+    using ngraph::pattern::Matcher;
+    Shape shape{};
+
+    auto a = make_shared<op::Parameter>(element::i32, shape);
+    auto b = make_shared<op::Parameter>(element::i32, shape);
+    auto abs1 = make_shared<op::Abs>(a);
+    auto abs2 = make_shared<op::Abs>(b);
+    auto add = abs1 + abs2;
+
+    auto pa = make_shared<op::Parameter>(element::i32, shape);
+    auto pb = make_shared<op::Parameter>(element::i32, shape);
+    auto pabs1 = make_shared<op::Abs>(pa);
+    auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
+    auto pabs2 = make_shared<op::Abs>(b);
+    auto padd = pabs1_label + pabs2;
+
+    {
+        Matcher n1(padd);
+        ASSERT_TRUE(n1.match(add));
+        auto r1 = n1.get_pattern_map()[pabs1_label];
+        ASSERT_TRUE(n1.match(add));
+        ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
+    }
+}
+
TEST(pattern, recurrent_pattern)
{
    using ngraph::pattern::RecurrentMatcher;
test/util/autodiff/backprop_derivative.hpp
@@ -90,20 +90,7 @@ namespace ngraph
            auto c_vec = read_vector<T>(c_arg);
            fill(c_vec.begin(), c_vec.end(), static_cast<T>(0));

-            static std::unordered_map<std::shared_ptr<Function>,
-                                      std::shared_ptr<runtime::Executable>>
-                s_compiled_functions;
-            auto it = s_compiled_functions.find(df);
-            std::shared_ptr<runtime::Executable> df_handle;
-            if (it == s_compiled_functions.end())
-            {
-                df_handle = backend->compile(df);
-                s_compiled_functions.insert({df, df_handle});
-            }
-            else
-            {
-                df_handle = it->second;
-            }
+            auto df_handle = backend->compile(df);

            // for each element of the adjoint
            // same as saying for each element of y
@@ -212,20 +199,7 @@ namespace ngraph
                s_clone_fwd_map[f] = clone_function(*fprop_cache.fprop);
            }
            auto clone_fwd = s_clone_fwd_map[f];

-            static std::unordered_map<std::shared_ptr<Function>,
-                                      std::shared_ptr<runtime::Executable>>
-                s_compiled_functions;
-            auto it = s_compiled_functions.find(clone_fwd);
-            std::shared_ptr<runtime::Executable> clone_fwd_handle;
-            if (it == s_compiled_functions.end())
-            {
-                clone_fwd_handle = backend->compile(clone_fwd);
-                s_compiled_functions.insert({clone_fwd, clone_fwd_handle});
-            }
-            else
-            {
-                clone_fwd_handle = it->second;
-            }
+            auto clone_fwd_handle = backend->compile(clone_fwd);

            clone_fwd_handle->call_with_validate(mod_f_output_args, f_input_args);