Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
2243df56
Commit
2243df56
authored
Jul 23, 2019
by
nmostafa
Browse files
Options
Browse Files
Download
Plain Diff
Merge remote-tracking branch 'upstream/master' into nmostafa/gather
parents
443cbc8a
607445a4
Hide whitespace changes
Inline
Side-by-side
Showing
54 changed files
with
356 additions
and
223 deletions
+356
-223
CMakeLists.txt
CMakeLists.txt
+2
-2
util.py
python/test/ngraph/util.py
+13
-11
lowerer.cpp
src/contrib/mlir/lowerer.cpp
+1
-2
mlir_subgraph_extraction.cpp
src/contrib/mlir/pass/mlir_subgraph_extraction.cpp
+0
-2
check.hpp
src/ngraph/check.hpp
+1
-1
node.hpp
src/ngraph/node.hpp
+1
-1
add.cpp
src/ngraph/op/add.cpp
+1
-1
add.hpp
src/ngraph/op/add.hpp
+2
-3
and.hpp
src/ngraph/op/and.hpp
+1
-2
batch_norm.cpp
src/ngraph/op/batch_norm.cpp
+49
-48
batch_norm.hpp
src/ngraph/op/batch_norm.hpp
+28
-29
divide.cpp
src/ngraph/op/divide.cpp
+1
-1
divide.hpp
src/ngraph/op/divide.hpp
+1
-2
dot.hpp
src/ngraph/op/dot.hpp
+1
-1
equal.hpp
src/ngraph/op/equal.hpp
+2
-0
gather.cpp
src/ngraph/op/gather.cpp
+2
-0
gather.hpp
src/ngraph/op/gather.hpp
+7
-4
gather_nd.cpp
src/ngraph/op/gather_nd.cpp
+2
-0
gather_nd.hpp
src/ngraph/op/gather_nd.hpp
+6
-2
greater.cpp
src/ngraph/op/greater.cpp
+5
-3
greater.hpp
src/ngraph/op/greater.hpp
+7
-2
greater_eq.cpp
src/ngraph/op/greater_eq.cpp
+5
-3
greater_eq.hpp
src/ngraph/op/greater_eq.hpp
+7
-2
less.cpp
src/ngraph/op/less.cpp
+4
-4
less.hpp
src/ngraph/op/less.hpp
+7
-2
less_eq.cpp
src/ngraph/op/less_eq.cpp
+5
-3
less_eq.hpp
src/ngraph/op/less_eq.hpp
+7
-2
log.cpp
src/ngraph/op/log.cpp
+4
-2
log.hpp
src/ngraph/op/log.hpp
+6
-1
lrn.cpp
src/ngraph/op/lrn.cpp
+5
-3
lrn.hpp
src/ngraph/op/lrn.hpp
+10
-5
max.cpp
src/ngraph/op/max.cpp
+0
-4
max.hpp
src/ngraph/op/max.hpp
+1
-1
max_pool.cpp
src/ngraph/op/max_pool.cpp
+17
-13
max_pool.hpp
src/ngraph/op/max_pool.hpp
+37
-10
maximum.cpp
src/ngraph/op/maximum.cpp
+5
-3
maximum.hpp
src/ngraph/op/maximum.hpp
+8
-3
min.cpp
src/ngraph/op/min.cpp
+0
-4
min.hpp
src/ngraph/op/min.hpp
+1
-1
minimum.cpp
src/ngraph/op/minimum.cpp
+5
-3
minimum.hpp
src/ngraph/op/minimum.hpp
+8
-2
multiply.cpp
src/ngraph/op/multiply.cpp
+6
-4
multiply.hpp
src/ngraph/op/multiply.hpp
+9
-5
negative.cpp
src/ngraph/op/negative.cpp
+5
-3
negative.hpp
src/ngraph/op/negative.hpp
+8
-2
not.cpp
src/ngraph/op/not.cpp
+4
-2
not.hpp
src/ngraph/op/not.hpp
+6
-1
not_equal.cpp
src/ngraph/op/not_equal.cpp
+5
-3
not_equal.hpp
src/ngraph/op/not_equal.hpp
+9
-2
one_hot.cpp
src/ngraph/op/one_hot.cpp
+4
-2
one_hot.hpp
src/ngraph/op/one_hot.hpp
+7
-3
or.cpp
src/ngraph/op/or.cpp
+4
-4
or.hpp
src/ngraph/op/or.hpp
+6
-4
CMakeLists.txt
src/ngraph/runtime/plaidml/CMakeLists.txt
+8
-0
No files found.
CMakeLists.txt
View file @
2243df56
...
...
@@ -306,12 +306,12 @@ set(NGRAPH_INSTALL_DOC "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_DOCDIR}")
set
(
NGRAPH_INSTALL_BIN
"
${
CMAKE_INSTALL_PREFIX
}
/
${
CMAKE_INSTALL_BINDIR
}
"
)
if
(
LINUX
)
if
(
DEFINED NGRAPH_RPATH
)
set
(
CMAKE_BUILD_RPATH
"$ORIGIN:
${
NGRAPH_RPATH
}
"
)
set
(
CMAKE_INSTALL_RPATH
"$ORIGIN:
${
NGRAPH_RPATH
}
"
)
else
()
set
(
CMAKE_BUILD_RPATH
"$ORIGIN"
)
set
(
CMAKE_INSTALL_RPATH
"$ORIGIN"
)
endif
()
set
(
CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE
)
set
(
CMAKE_BUILD_WITH_INSTALL_RPATH TRUE
)
endif
()
#-----------------------------------------------------------------------------------------------
...
...
python/test/ngraph/util.py
View file @
2243df56
...
...
@@ -16,9 +16,8 @@
import
numpy
as
np
import
ngraph
as
ng
from
string
import
ascii_uppercase
from
ngraph.utils.types
import
NumericData
from
typing
import
Any
,
Callable
,
List
import
test
...
...
@@ -32,10 +31,14 @@ def get_runtime():
def
run_op_node
(
input_data
,
op_fun
,
*
args
):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`.
`op_fun` has to accept a node as an argument.
This function converts passed raw input data to nGraph Constant Node and that form is passed
to `op_fun`.
:param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out.
...
...
@@ -45,14 +48,8 @@ def run_op_node(input_data, op_fun, *args):
comp_args
=
[]
op_fun_args
=
[]
comp_inputs
=
[]
for
idx
,
data
in
enumerate
(
input_data
):
if
np
.
isscalar
(
data
):
op_fun_args
.
append
(
ng
.
constant
(
data
,
_get_numpy_dtype
(
data
)))
else
:
node
=
ng
.
parameter
(
data
.
shape
,
name
=
ascii_uppercase
[
idx
],
dtype
=
data
.
dtype
)
op_fun_args
.
append
(
node
)
comp_args
.
append
(
node
)
comp_inputs
.
append
(
data
)
for
data
in
input_data
:
op_fun_args
.
append
(
ng
.
constant
(
data
,
_get_numpy_dtype
(
data
)))
op_fun_args
.
extend
(
args
)
node
=
op_fun
(
*
op_fun_args
)
computation
=
runtime
.
computation
(
node
,
*
comp_args
)
...
...
@@ -60,10 +57,15 @@ def run_op_node(input_data, op_fun, *args):
def
run_op_numeric_data
(
input_data
,
op_fun
,
*
args
):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`.
`op_fun` has to accept a scalar or an array.
This function passess input data AS IS. This mean that in case they're a scalar (integral,
or floating point value) or a NumPy's ndarray object they will be automatically converted
to nGraph's Constant Nodes.
:param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out.
...
...
src/contrib/mlir/lowerer.cpp
View file @
2243df56
...
...
@@ -350,6 +350,7 @@ namespace
}
return
callBackFuncPtr
;
}
// NGDialect converters
Type
NGraphTypeConverter
::
convertType
(
Type
type
)
{
...
...
@@ -576,7 +577,6 @@ namespace
// Create Value for result, and extract type info.
Value
*
result
=
m_pass
.
buildOutputDefs
(
op
,
rewriter
)[
0
];
NGRAPH_CHECK
(
result
,
"Unexpected null result in ConcatOp"
);
auto
resultTy
=
result
->
getType
().
cast
<
MemRefType
>
();
// Create view to write into result.
MemRefView
vRes
(
result
);
...
...
@@ -590,7 +590,6 @@ namespace
for
(
auto
&
operand
:
operands
)
{
NGRAPH_CHECK
(
operand
,
"Unexpected null operand in ConcatOp"
);
auto
operandTy
=
result
->
getType
().
cast
<
MemRefType
>
();
// Assuming rank = r, and the concatenation axis is A where A<r, we'll be creating
// loops of this form:
...
...
src/contrib/mlir/pass/mlir_subgraph_extraction.cpp
View file @
2243df56
...
...
@@ -75,7 +75,6 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2)
// Associate nodes of second sub-graph to first one
auto
sg_nodes
=
sg2
.
get_nodes
();
auto
&
node_map
=
m_pass
.
m_node_to_graph
;
for
(
auto
node
:
sg_nodes
)
{
NGRAPH_DEBUG
<<
*
node
;
...
...
@@ -113,7 +112,6 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
for
(
auto
op
:
func
->
get_ordered_ops
())
{
NodeVector
inputs
;
int
first_graph_id
=
-
1
;
std
::
unordered_set
<
int
>
subgraph_ids
;
// unsupported ops, skip
if
(
!
is_supported_mlir_op
(
op
))
...
...
src/ngraph/check.hpp
View file @
2243df56
...
...
@@ -160,5 +160,5 @@ namespace ngraph
/// \brief Macro to signal a code path that is unreachable in a successful execution. It's
/// implemented with NGRAPH_CHECK macro.
/// \param ... Additional error message that should describe why that execution path is unreachable.
/// \throws ::ngrap::CheckFailure if the macro is executed.
/// \throws ::ngrap
h
::CheckFailure if the macro is executed.
#define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", ##__VA_ARGS__)
src/ngraph/node.hpp
View file @
2243df56
...
...
@@ -214,7 +214,7 @@ namespace ngraph
virtual
bool
is_constant
()
const
;
virtual
bool
is_null
()
const
{
return
false
;
}
virtual
bool
is_op
()
const
{
return
false
;
}
virtual
bool
is_commutative
()
{
return
false
;
}
virtual
bool
is_commutative
()
const
{
return
false
;
}
virtual
bool
is_dynamic
()
const
;
virtual
bool
has_state
()
const
{
return
false
;
}
size_t
get_instance_id
()
const
{
return
m_instance_id
;
}
...
...
src/ngraph/op/add.cpp
View file @
2243df56
...
...
@@ -49,7 +49,7 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
adjoints
.
add_delta
(
y
,
delta
);
}
shared_ptr
<
Node
>
ngraph
::
operator
+
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
)
shared_ptr
<
Node
>
ngraph
::
operator
+
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
)
{
return
make_shared
<
op
::
Add
>
(
arg0
,
arg1
);
}
src/ngraph/op/add.hpp
View file @
2243df56
...
...
@@ -51,13 +51,12 @@ namespace ngraph
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
bool
is_commutative
()
const
override
{
return
true
;
}
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
virtual
bool
is_commutative
()
override
{
return
true
;
}
};
}
std
::
shared_ptr
<
ngraph
::
Node
>
operator
+
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
arg0
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
arg1
);
std
::
shared_ptr
<
Node
>
operator
+
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
);
}
src/ngraph/op/and.hpp
View file @
2243df56
...
...
@@ -51,8 +51,7 @@ namespace ngraph
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
protected
:
virtual
bool
is_commutative
()
override
{
return
true
;
}
virtual
bool
is_commutative
()
const
override
{
return
true
;
}
};
}
}
src/ngraph/op/batch_norm.cpp
View file @
2243df56
...
...
@@ -22,12 +22,15 @@
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp"
const
std
::
string
ngraph
::
op
::
BatchNormTraining
::
type_name
{
"BatchNormTraining"
};
using
namespace
std
;
using
namespace
ngraph
;
ngraph
::
op
::
BatchNormTraining
::
BatchNormTraining
(
Output
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
beta
,
double
epsilon
)
const
string
op
::
BatchNormTraining
::
type_name
{
"BatchNormTraining"
};
op
::
BatchNormTraining
::
BatchNormTraining
(
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
double
epsilon
)
:
Op
({
gamma
,
beta
,
input
})
,
m_epsilon
(
epsilon
)
{
...
...
@@ -35,17 +38,17 @@ ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
}
// DEPRECATED
ngraph
::
op
::
BatchNormTraining
::
BatchNormTraining
(
double
eps
,
Output
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
input
)
op
::
BatchNormTraining
::
BatchNormTraining
(
double
eps
,
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
const
Output
<
Node
>&
input
)
:
Op
({
gamma
,
beta
,
input
})
,
m_epsilon
(
eps
)
{
constructor_validate_and_infer_types
();
}
void
ngraph
::
op
::
BatchNormTraining
::
validate_and_infer_types
()
void
op
::
BatchNormTraining
::
validate_and_infer_types
()
{
element
::
Type
result_et
;
PartialShape
result_batch_shape
;
...
...
@@ -66,16 +69,15 @@ void ngraph::op::BatchNormTraining::validate_and_infer_types()
set_output_type
(
2
,
result_et
,
result_channel_shape
);
}
std
::
shared_ptr
<
ngraph
::
Node
>
ngraph
::
op
::
BatchNormTraining
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
std
::
shared_ptr
<
Node
>
op
::
BatchNormTraining
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
return
std
::
make_shared
<
BatchNormTraining
>
(
new_args
.
at
(
2
),
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_epsilon
);
}
void
ngraph
::
op
::
BatchNormTraining
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
void
op
::
BatchNormTraining
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
{
auto
gamma
=
input
(
0
).
get_source_output
();
auto
beta
=
input
(
1
).
get_source_output
();
...
...
@@ -102,14 +104,14 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
adjoints
.
add_delta
(
beta
,
dbeta
);
}
const
st
d
::
string
ngraph
::
op
::
BatchNormInference
::
type_name
{
"BatchNormInference"
};
const
st
ring
op
::
BatchNormInference
::
type_name
{
"BatchNormInference"
};
ngraph
::
op
::
BatchNormInference
::
BatchNormInference
(
Output
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
variance
,
double
epsilon
)
op
::
BatchNormInference
::
BatchNormInference
(
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
const
Output
<
Node
>&
mean
,
const
Output
<
Node
>&
variance
,
double
epsilon
)
:
Op
({
gamma
,
beta
,
input
,
mean
,
variance
})
,
m_epsilon
(
epsilon
)
{
...
...
@@ -117,19 +119,19 @@ ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
}
// DEPRECATED
ngraph
::
op
::
BatchNormInference
::
BatchNormInference
(
double
eps
,
Output
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
variance
)
op
::
BatchNormInference
::
BatchNormInference
(
double
eps
,
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
mean
,
const
Output
<
Node
>&
variance
)
:
Op
({
gamma
,
beta
,
input
,
mean
,
variance
})
,
m_epsilon
(
eps
)
{
constructor_validate_and_infer_types
();
}
void
ngraph
::
op
::
BatchNormInference
::
validate_and_infer_types
()
void
op
::
BatchNormInference
::
validate_and_infer_types
()
{
element
::
Type
result_et
;
PartialShape
result_batch_shape
;
...
...
@@ -152,23 +154,22 @@ void ngraph::op::BatchNormInference::validate_and_infer_types()
set_output_type
(
0
,
result_et
,
result_batch_shape
);
}
std
::
shared_ptr
<
ngraph
::
Node
>
ngraph
::
op
::
BatchNormInference
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
std
::
shared_ptr
<
Node
>
op
::
BatchNormInference
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
return
std
::
make_shared
<
BatchNormInference
>
(
new_args
.
at
(
2
),
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
3
),
new_args
.
at
(
4
),
m_epsilon
);
}
const
st
d
::
string
ngraph
::
op
::
BatchNormTrainingBackprop
::
type_name
{
"BatchNormTrainingBackprop"
};
const
st
ring
op
::
BatchNormTrainingBackprop
::
type_name
{
"BatchNormTrainingBackprop"
};
ngraph
::
op
::
BatchNormTrainingBackprop
::
BatchNormTrainingBackprop
(
Output
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
variance
,
Output
<
ngraph
::
Node
>
delta
,
double
epsilon
)
op
::
BatchNormTrainingBackprop
::
BatchNormTrainingBackprop
(
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
const
Output
<
Node
>&
mean
,
const
Output
<
Node
>&
variance
,
const
Output
<
Node
>&
delta
,
double
epsilon
)
:
Op
({
gamma
,
beta
,
input
,
mean
,
variance
,
delta
})
,
m_epsilon
(
epsilon
)
...
...
@@ -177,13 +178,13 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::
constructor_validate_and_infer_types
();
}
ngraph
::
op
::
BatchNormTrainingBackprop
::
BatchNormTrainingBackprop
(
double
epsilon
,
Output
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
variance
,
Output
<
ngraph
::
Node
>
delta
)
op
::
BatchNormTrainingBackprop
::
BatchNormTrainingBackprop
(
double
epsilon
,
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
mean
,
const
Output
<
Node
>&
variance
,
const
Output
<
Node
>&
delta
)
:
Op
({
gamma
,
beta
,
input
,
mean
,
variance
,
delta
})
,
m_epsilon
(
epsilon
)
...
...
@@ -192,7 +193,7 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
constructor_validate_and_infer_types
();
}
void
ngraph
::
op
::
BatchNormTrainingBackprop
::
validate_and_infer_types
()
void
op
::
BatchNormTrainingBackprop
::
validate_and_infer_types
()
{
PartialShape
input_and_delta_shape
{
get_input_partial_shape
(
INPUT_DATA
)};
...
...
@@ -239,8 +240,8 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
set_output_type
(
2
,
result_et
,
result_channel_shape
);
}
std
::
shared_ptr
<
ngraph
::
Node
>
ngraph
::
op
::
BatchNormTrainingBackprop
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
std
::
shared_ptr
<
Node
>
op
::
BatchNormTrainingBackprop
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
return
std
::
make_shared
<
op
::
BatchNormTrainingBackprop
>
(
new_args
.
at
(
2
),
...
...
src/ngraph/op/batch_norm.hpp
View file @
2243df56
...
...
@@ -39,9 +39,9 @@ namespace ngraph
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormTraining
(
Output
<
Node
>
input
,
Output
<
Node
>
gamma
,
Output
<
Node
>
beta
,
BatchNormTraining
(
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
double
epsilon
);
NGRAPH_DEPRECATED_DOC
...
...
@@ -66,9 +66,9 @@ namespace ngraph
/// output[2]: shall have rank 1, with the same span as input's channel axis.
NGRAPH_DEPRECATED
(
"Use another constructor"
)
BatchNormTraining
(
double
eps
,
Output
<
Node
>
gamma
,
Output
<
Node
>
beta
,
Output
<
Node
>
input
);
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
const
Output
<
Node
>&
input
);
void
validate_and_infer_types
()
override
;
...
...
@@ -101,11 +101,11 @@ namespace ngraph
/// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference
(
Output
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
variance
,
BatchNormInference
(
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
const
Output
<
Node
>&
mean
,
const
Output
<
Node
>&
variance
,
double
epsilon
);
NGRAPH_DEPRECATED_DOC
...
...
@@ -128,11 +128,11 @@ namespace ngraph
/// output: shall have the same shape as 'input'.
NGRAPH_DEPRECATED
(
"Use another constructor"
)
BatchNormInference
(
double
eps
,
Output
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
variance
);
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
mean
,
const
Output
<
Node
>&
variance
);
void
validate_and_infer_types
()
override
;
...
...
@@ -165,24 +165,23 @@ namespace ngraph
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
BatchNormTrainingBackprop
()
=
default
;
BatchNormTrainingBackprop
(
Output
<
Node
>
input
,
Output
<
Node
>
gamma
,
Output
<
Node
>
beta
,
Output
<
Node
>
mean
,
Output
<
Node
>
variance
,
Output
<
Node
>
delta
,
BatchNormTrainingBackprop
(
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
const
Output
<
Node
>&
mean
,
const
Output
<
Node
>&
variance
,
const
Output
<
Node
>&
delta
,
double
epsilon
);
NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED
(
"Use another constructor"
)
BatchNormTrainingBackprop
(
double
epsilon
,
Output
<
Node
>
gamma
,
Output
<
Node
>
beta
,
Output
<
Node
>
input
,
Output
<
Node
>
mean
,
Output
<
Node
>
variance
,
Output
<
Node
>
delta
);
const
Output
<
Node
>&
gamma
,
const
Output
<
Node
>&
beta
,
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
mean
,
const
Output
<
Node
>&
variance
,
const
Output
<
Node
>&
delta
);
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/divide.cpp
View file @
2243df56
...
...
@@ -64,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
adjoints
.
add_delta
(
y
,
-
delta
*
shared_from_this
()
/
y
);
}
shared_ptr
<
Node
>
ngraph
::
operator
/
(
const
Output
<
Node
>
arg0
,
const
Output
<
Node
>
arg1
)
shared_ptr
<
Node
>
ngraph
::
operator
/
(
const
Output
<
Node
>
&
arg0
,
const
Output
<
Node
>&
arg1
)
{
return
make_shared
<
op
::
Divide
>
(
arg0
,
arg1
);
}
src/ngraph/op/divide.hpp
View file @
2243df56
...
...
@@ -64,6 +64,5 @@ namespace ngraph
};
}
std
::
shared_ptr
<
ngraph
::
Node
>
operator
/
(
const
Output
<
ngraph
::
Node
>
arg0
,
const
Output
<
ngraph
::
Node
>
arg1
);
std
::
shared_ptr
<
Node
>
operator
/
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
);
}
src/ngraph/op/dot.hpp
View file @
2243df56
...
...
@@ -58,7 +58,7 @@ namespace ngraph
void
validate_and_infer_types
()
override
;
size_t
get_reduction_axes_count
()
const
{
return
m_reduction_axes_count
;
}
void
g
et_reduction_axes_count
(
size_t
reduction_axes_count
)
void
s
et_reduction_axes_count
(
size_t
reduction_axes_count
)
{
m_reduction_axes_count
=
reduction_axes_count
;
}
...
...
src/ngraph/op/equal.hpp
View file @
2243df56
...
...
@@ -56,6 +56,8 @@ namespace ngraph
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
bool
is_commutative
()
const
override
{
return
true
;
}
};
}
}
src/ngraph/op/gather.cpp
View file @
2243df56
...
...
@@ -23,6 +23,8 @@ using namespace ngraph;
static
int
PARAMS
=
0
;
static
int
INDICES
=
1
;
const
string
op
::
Gather
::
type_name
{
"Gather"
};
shared_ptr
<
Node
>
op
::
Gather
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/gather.hpp
View file @
2243df56
...
...
@@ -26,13 +26,15 @@ namespace ngraph
class
Gather
:
public
Op
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
Gather
()
=
default
;
/// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param axis Axis in params to gather
Gather
(
const
std
::
shared_ptr
<
Node
>&
params
,
const
std
::
shared_ptr
<
Node
>&
indices
,
size_t
axis
=
0
)
:
Op
(
"Gather"
,
check_single_output_args
({
params
,
indices
}))
Gather
(
const
Output
<
Node
>&
params
,
const
Output
<
Node
>&
indices
,
size_t
axis
=
0
)
:
Op
({
params
,
indices
})
,
m_axis
(
axis
)
{
constructor_validate_and_infer_types
();
...
...
@@ -46,6 +48,7 @@ namespace ngraph
}
size_t
get_axis
()
const
{
return
m_axis
;
}
void
set_axis
(
size_t
axis
)
{
m_axis
=
axis
;
}
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/gather_nd.cpp
View file @
2243df56
...
...
@@ -23,6 +23,8 @@ using namespace ngraph;
static
int
PARAMS
=
0
;
static
int
INDICES
=
1
;
const
string
op
::
GatherND
::
type_name
{
"GatherND"
};
shared_ptr
<
Node
>
op
::
GatherND
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/gather_nd.hpp
View file @
2243df56
...
...
@@ -26,10 +26,14 @@ namespace ngraph
class
GatherND
:
public
Op
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
GatherND
()
=
default
;
/// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
GatherND
(
const
std
::
shared_ptr
<
Node
>&
params
,
const
std
::
shared_ptr
<
Node
>&
indices
)
:
Op
(
"GatherND"
,
check_single_output_args
({
params
,
indices
})
)
GatherND
(
const
Output
<
Node
>&
params
,
const
Output
<
Node
>&
indices
)
:
Op
(
{
params
,
indices
}
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/greater.cpp
View file @
2243df56
...
...
@@ -19,10 +19,12 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
Greater
::
Greater
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
string
op
::
Greater
::
type_name
{
"Greater"
};
op
::
Greater
::
Greater
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseComparison
(
"Greater"
,
arg0
,
arg1
,
autob
)
:
BinaryElementwiseComparison
(
arg0
,
arg1
,
autob
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/greater.hpp
View file @
2243df56
...
...
@@ -26,13 +26,18 @@ namespace ngraph
class
Greater
:
public
util
::
BinaryElementwiseComparison
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a greater-than operation.
Greater
()
=
default
;
/// \brief Constructs a greater-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Greater
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
Greater
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/greater_eq.cpp
View file @
2243df56
...
...
@@ -19,10 +19,12 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
GreaterEq
::
GreaterEq
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
string
op
::
GreaterEq
::
type_name
{
"GreaterEq"
};
op
::
GreaterEq
::
GreaterEq
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseComparison
(
"GreaterEq"
,
arg0
,
arg1
,
autob
)
:
BinaryElementwiseComparison
(
arg0
,
arg1
,
autob
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/greater_eq.hpp
View file @
2243df56
...
...
@@ -26,13 +26,18 @@ namespace ngraph
class
GreaterEq
:
public
util
::
BinaryElementwiseComparison
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a greater-than-or-equal operation.
GreaterEq
()
=
default
;
/// \brief Constructs a greater-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
GreaterEq
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
GreaterEq
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/less.cpp
View file @
2243df56
...
...
@@ -19,10 +19,10 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
Less
::
Less
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseComparison
(
"Less"
,
arg0
,
arg1
,
autob
)
const
string
op
::
Less
::
type_name
{
"Less"
};
op
::
Less
::
Less
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseComparison
(
arg0
,
arg1
,
autob
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/less.hpp
View file @
2243df56
...
...
@@ -26,13 +26,18 @@ namespace ngraph
class
Less
:
public
util
::
BinaryElementwiseComparison
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a less-than operation.
Less
()
=
default
;
/// \brief Constructs a less-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Less
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
Less
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/less_eq.cpp
View file @
2243df56
...
...
@@ -19,10 +19,12 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
LessEq
::
LessEq
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
string
op
::
LessEq
::
type_name
{
"LessEq"
};
op
::
LessEq
::
LessEq
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseComparison
(
"LessEq"
,
arg0
,
arg1
,
autob
)
:
BinaryElementwiseComparison
(
arg0
,
arg1
,
autob
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/less_eq.hpp
View file @
2243df56
...
...
@@ -26,13 +26,18 @@ namespace ngraph
class
LessEq
:
public
util
::
BinaryElementwiseComparison
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a less-than-or-equal operation.
LessEq
()
=
default
;
/// \brief Constructs a less-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
LessEq
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
LessEq
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/log.cpp
View file @
2243df56
...
...
@@ -20,8 +20,10 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
Log
::
Log
(
const
shared_ptr
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
"Log"
,
arg
)
const
string
op
::
Log
::
type_name
{
"Log"
};
op
::
Log
::
Log
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/log.hpp
View file @
2243df56
...
...
@@ -26,10 +26,15 @@ namespace ngraph
class
Log
:
public
util
::
UnaryElementwiseArithmetic
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a natural log operation.
Log
()
=
default
;
/// \brief Constructs a natural log operation.
///
/// \param arg Node that produces the input tensor.
Log
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Log
(
const
Output
<
Node
>&
arg
);
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/lrn.cpp
View file @
2243df56
...
...
@@ -20,12 +20,14 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
LRN
::
LRN
(
const
std
::
shared_ptr
<
Node
>&
arg
,
double
alpha
,
double
beta
,
double
bias
,
size_t
nsize
)
:
UnaryElementwiseArithmetic
(
"LRN"
,
arg
)
const
string
op
::
LRN
::
type_name
{
"LRN"
};
op
::
LRN
::
LRN
(
const
Output
<
Node
>&
arg
,
double
alpha
,
double
beta
,
double
bias
,
size_t
size
)
:
UnaryElementwiseArithmetic
(
arg
)
,
m_alpha
(
alpha
)
,
m_beta
(
beta
)
,
m_bias
(
bias
)
,
m_size
(
n
size
)
,
m_size
(
size
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/lrn.hpp
View file @
2243df56
...
...
@@ -38,23 +38,28 @@ namespace ngraph
class
LRN
:
public
util
::
UnaryElementwiseArithmetic
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a LRN operation.
LRN
()
=
default
;
/// \brief Constructs a LRN operation.
///
/// \param arg Node that produces the input tensor.
LRN
(
const
std
::
shared_ptr
<
Node
>&
arg
,
double
alpha
,
double
beta
,
double
bias
,
size_t
size
);
LRN
(
const
Output
<
Node
>&
arg
,
double
alpha
,
double
beta
,
double
bias
,
size_t
size
);
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
void
validate_and_infer_types
()
override
;
double
get_alpha
()
const
{
return
m_alpha
;
}
void
set_alpha
(
double
alpha
)
{
m_alpha
=
alpha
;
}
double
get_beta
()
const
{
return
m_beta
;
}
void
set_beta
(
double
beta
)
{
m_beta
=
beta
;
}
double
get_bias
()
const
{
return
m_bias
;
}
void
set_bias
(
double
bias
)
{
m_bias
=
bias
;
}
size_t
get_nsize
()
const
{
return
m_size
;
}
void
set_nsize
(
size_t
size
)
{
m_size
=
size
;
}
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
...
...
src/ngraph/op/max.cpp
View file @
2243df56
...
...
@@ -22,10 +22,6 @@ using namespace ngraph;
const
string
op
::
Max
::
type_name
{
"Max"
};
op
::
Max
::
Max
()
{
}
op
::
Max
::
Max
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reduction_axes
)
:
ArithmeticReduction
(
arg
,
reduction_axes
)
{
...
...
src/ngraph/op/max.hpp
View file @
2243df56
...
...
@@ -30,7 +30,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a "max" reduction operation.
Max
();
Max
()
=
default
;
/// \brief Constructs a max-reduction operation.
///
/// \param arg The tensor to be reduced.
...
...
src/ngraph/op/max_pool.cpp
View file @
2243df56
...
...
@@ -25,14 +25,16 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
MaxPool
::
MaxPool
(
const
shared_ptr
<
Node
>&
arg
,
const
string
op
::
MaxPool
::
type_name
{
"MaxPool"
};
op
::
MaxPool
::
MaxPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Shape
&
padding_below
,
const
Shape
&
padding_above
,
const
PadType
&
pad_type
,
bool
ceil_mode
)
:
Op
(
"MaxPool"
,
check_single_output_args
({
arg
})
)
:
Op
(
{
arg
}
)
,
m_window_shape
(
window_shape
)
,
m_window_movement_strides
(
window_movement_strides
)
,
m_padding_below
(
padding_below
)
...
...
@@ -43,7 +45,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types
();
}
op
::
MaxPool
::
MaxPool
(
const
shared_ptr
<
Node
>&
arg
,
op
::
MaxPool
::
MaxPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Shape
&
padding_below
,
...
...
@@ -54,7 +56,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
{
}
op
::
MaxPool
::
MaxPool
(
const
shared_ptr
<
Node
>&
arg
,
op
::
MaxPool
::
MaxPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Shape
&
padding_below
,
...
...
@@ -121,14 +123,14 @@ void op::MaxPool::validate_and_infer_types()
m_ceil_mode
));
}
op
::
MaxPool
::
MaxPool
(
const
shared_ptr
<
Node
>&
arg
,
op
::
MaxPool
::
MaxPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
)
:
MaxPool
(
arg
,
window_shape
,
window_movement_strides
,
Shape
(),
Shape
())
{
}
op
::
MaxPool
::
MaxPool
(
const
shared_ptr
<
Node
>&
arg
,
const
Shape
&
window_shape
)
op
::
MaxPool
::
MaxPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
)
:
MaxPool
(
arg
,
window_shape
,
Strides
(),
Shape
(),
Shape
())
{
}
...
...
@@ -145,13 +147,15 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) con
m_ceil_mode
);
}
op
::
MaxPoolBackprop
::
MaxPoolBackprop
(
const
shared_ptr
<
Node
>&
arg_forward
,
const
shared_ptr
<
Node
>&
delta
,
const
string
op
::
MaxPoolBackprop
::
type_name
{
"MaxPoolBackprop"
};
op
::
MaxPoolBackprop
::
MaxPoolBackprop
(
const
Output
<
Node
>&
arg_forward
,
const
Output
<
Node
>&
delta
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Shape
&
padding_below
,
const
Shape
&
padding_above
)
:
Op
(
"MaxPoolBackprop"
,
check_single_output_args
({
arg_forward
,
delta
})
)
:
Op
(
{
arg_forward
,
delta
}
)
,
m_window_shape
(
window_shape
)
,
m_window_movement_strides
(
window_movement_strides
)
,
m_padding_below
(
padding_below
)
...
...
@@ -160,14 +164,14 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
constructor_validate_and_infer_types
();
}
op
::
MaxPoolBackprop
::
MaxPoolBackprop
(
const
shared_ptr
<
Node
>&
arg_forward
,
const
shared_ptr
<
Node
>&
delta
,
const
shared_ptr
<
Node
>&
result_forward
,
op
::
MaxPoolBackprop
::
MaxPoolBackprop
(
const
Output
<
Node
>&
arg_forward
,
const
Output
<
Node
>&
delta
,
const
Output
<
Node
>&
result_forward
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Shape
&
padding_below
,
const
Shape
&
padding_above
)
:
Op
(
"MaxPoolBackprop"
,
check_single_output_args
({
arg_forward
,
delta
,
result_forward
})
)
:
Op
(
{
arg_forward
,
delta
,
result_forward
}
)
,
m_window_shape
(
window_shape
)
,
m_window_movement_strides
(
window_movement_strides
)
,
m_padding_below
(
padding_below
)
...
...
src/ngraph/op/max_pool.hpp
View file @
2243df56
...
...
@@ -28,6 +28,12 @@ namespace ngraph
class
MaxPool
:
public
Op
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a batched max pooling operation.
MaxPool
()
=
default
;
/// \brief Constructs a batched max pooling operation.
///
/// \param arg The node producing the input data batch tensor.
...
...
@@ -37,7 +43,7 @@ namespace ngraph
/// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes
/// \param ceil_mode Whether to use ceiling while computing output shape.
MaxPool
(
const
std
::
shared_ptr
<
Node
>&
arg
,
MaxPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Shape
&
padding_below
,
...
...
@@ -53,7 +59,7 @@ namespace ngraph
/// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes
MaxPool
(
const
std
::
shared_ptr
<
Node
>&
arg
,
MaxPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Shape
&
padding_below
,
...
...
@@ -67,7 +73,7 @@ namespace ngraph
/// \param window_movement_strides The window movement strides.
/// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape.
MaxPool
(
const
std
::
shared_ptr
<
Node
>&
arg
,
MaxPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Shape
&
padding_below
,
...
...
@@ -80,7 +86,7 @@ namespace ngraph
/// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides.
MaxPool
(
const
std
::
shared_ptr
<
Node
>&
arg
,
MaxPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
);
...
...
@@ -88,23 +94,32 @@ namespace ngraph
///
/// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape.
MaxPool
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
Shape
&
window_shape
);
MaxPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
);
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
/// \return The window shape.
const
Shape
&
get_window_shape
()
const
{
return
m_window_shape
;
}
void
set_window_shape
(
const
Shape
&
window_shape
)
{
m_window_shape
=
window_shape
;
}
/// \return The window movement strides.
const
Strides
&
get_window_movement_strides
()
const
{
return
m_window_movement_strides
;
}
void
set_window_movement_strides
(
const
Strides
&
window_movement_strides
)
{
m_window_movement_strides
=
window_movement_strides
;
}
/// \return The below-padding shape.
const
Shape
&
get_padding_below
()
const
{
return
m_padding_below
;
}
void
set_padding_below
(
const
Shape
&
padding_below
)
{
m_padding_below
=
padding_below
;
}
/// \return The above-padding shape.
const
Shape
&
get_padding_above
()
const
{
return
m_padding_above
;
}
void
set_adding_above
(
const
Shape
&
padding_above
)
{
m_padding_above
=
padding_above
;
}
/// \return The pad type for pooling.
const
PadType
&
get_pad_type
()
const
{
return
m_pad_type
;
}
void
set_pad_type
(
const
PadType
&
pad_type
)
{
m_pad_type
=
pad_type
;
}
/// \return The ceiling mode being used for output shape computations
bool
get_ceil_mode
()
const
{
return
m_ceil_mode
;
}
void
set_ceil_mode
(
bool
ceil_mode
)
{
m_ceil_mode
=
ceil_mode
;
}
/// \return The default value for MaxPool.
virtual
std
::
shared_ptr
<
Node
>
get_default_value
()
const
override
{
...
...
@@ -126,16 +141,21 @@ namespace ngraph
class
MaxPoolBackprop
:
public
Op
{
public
:
MaxPoolBackprop
(
const
std
::
shared_ptr
<
Node
>&
arg_forward
,
const
std
::
shared_ptr
<
Node
>&
delta
,
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
MaxPoolBackprop
()
=
default
;
MaxPoolBackprop
(
const
Output
<
Node
>&
arg_forward
,
const
Output
<
Node
>&
delta
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Shape
&
padding_below
,
const
Shape
&
padding_above
);
MaxPoolBackprop
(
const
std
::
shared_ptr
<
Node
>&
arg_forward
,
const
std
::
shared_ptr
<
Node
>&
delta
,
const
std
::
shared_ptr
<
Node
>&
result_forward
,
MaxPoolBackprop
(
const
Output
<
Node
>&
arg_forward
,
const
Output
<
Node
>&
delta
,
const
Output
<
Node
>&
result_forward
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Shape
&
padding_below
,
...
...
@@ -147,9 +167,16 @@ namespace ngraph
void
validate_and_infer_types
()
override
;
const
Shape
&
get_window_shape
()
const
{
return
m_window_shape
;
}
void
set_window_shape
(
const
Shape
&
window_shape
)
{
m_window_shape
=
window_shape
;
}
const
Strides
&
get_window_movement_strides
()
const
{
return
m_window_movement_strides
;
}
void
set_window_movement_strides
(
const
Strides
&
window_movement_strides
)
{
m_window_movement_strides
=
window_movement_strides
;
}
const
Shape
&
get_padding_below
()
const
{
return
m_padding_below
;
}
void
set_padding_below
(
const
Shape
&
padding_below
)
{
m_padding_below
=
padding_below
;
}
const
Shape
&
get_padding_above
()
const
{
return
m_padding_above
;
}
void
set_padding_above
(
const
Shape
&
padding_above
)
{
m_padding_above
=
padding_above
;
}
protected
:
Shape
m_window_shape
;
Strides
m_window_movement_strides
;
...
...
src/ngraph/op/maximum.cpp
View file @
2243df56
...
...
@@ -25,10 +25,12 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
Maximum
::
Maximum
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
string
op
::
Maximum
::
type_name
{
"Maximum"
};
op
::
Maximum
::
Maximum
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseArithmetic
(
"Maximum"
,
arg0
,
arg1
,
autob
)
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
,
autob
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/maximum.hpp
View file @
2243df56
...
...
@@ -26,19 +26,24 @@ namespace ngraph
class
Maximum
:
public
util
::
BinaryElementwiseArithmetic
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a maximum operation.
Maximum
()
=
default
;
/// \brief Constructs a maximum operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Maximum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
Maximum
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
bool
is_commutative
()
override
{
return
true
;
}
virtual
bool
is_commutative
()
const
override
{
return
true
;
}
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
...
...
src/ngraph/op/min.cpp
View file @
2243df56
...
...
@@ -22,10 +22,6 @@ using namespace ngraph;
const
string
op
::
Min
::
type_name
{
"Min"
};
op
::
Min
::
Min
()
{
}
op
::
Min
::
Min
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reduction_axes
)
:
ArithmeticReduction
(
arg
,
reduction_axes
)
{
...
...
src/ngraph/op/min.hpp
View file @
2243df56
...
...
@@ -30,7 +30,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a "min" reduction operation.
Min
();
Min
()
=
default
;
/// \brief Constructs a min-reduction operation.
///
/// \param arg The tensor to be reduced.
...
...
src/ngraph/op/minimum.cpp
View file @
2243df56
...
...
@@ -25,10 +25,12 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
Minimum
::
Minimum
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
string
op
::
Minimum
::
type_name
{
"Minimum"
};
op
::
Minimum
::
Minimum
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseArithmetic
(
"Minimum"
,
arg0
,
arg1
,
autob
)
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
,
autob
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/minimum.hpp
View file @
2243df56
...
...
@@ -26,18 +26,24 @@ namespace ngraph
class
Minimum
:
public
util
::
BinaryElementwiseArithmetic
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a minimum operation.
Minimum
()
=
default
;
/// \brief Constructs a minimum operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Minimum
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
Minimum
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
bool
is_commutative
()
const
override
{
return
true
;
}
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
...
...
src/ngraph/op/multiply.cpp
View file @
2243df56
...
...
@@ -19,10 +19,12 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
Multiply
::
Multiply
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
string
op
::
Multiply
::
type_name
{
"Multiply"
};
op
::
Multiply
::
Multiply
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseArithmetic
(
"Multiply"
,
arg0
,
arg1
,
autob
)
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
,
autob
)
{
constructor_validate_and_infer_types
();
}
...
...
@@ -49,7 +51,7 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints
.
add_delta
(
y
,
x
*
delta
);
}
shared_ptr
<
Node
>
ngraph
::
operator
*
(
const
shared_ptr
<
Node
>
arg0
,
const
shared_ptr
<
Node
>
arg1
)
shared_ptr
<
Node
>
ngraph
::
operator
*
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
)
{
return
make_shared
<
op
::
Multiply
>
(
arg0
,
arg1
);
}
src/ngraph/op/multiply.hpp
View file @
2243df56
...
...
@@ -26,25 +26,29 @@ namespace ngraph
class
Multiply
:
public
util
::
BinaryElementwiseArithmetic
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a multiplication operation.
Multiply
()
=
default
;
/// \brief Constructs a multiplication operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Multiply
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
Multiply
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
bool
is_commutative
()
const
override
{
return
true
;
}
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
virtual
bool
is_commutative
()
override
{
return
true
;
}
};
};
std
::
shared_ptr
<
ngraph
::
Node
>
operator
*
(
const
std
::
shared_ptr
<
ngraph
::
Node
>
arg0
,
const
std
::
shared_ptr
<
ngraph
::
Node
>
arg1
);
std
::
shared_ptr
<
Node
>
operator
*
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
);
}
src/ngraph/op/negative.cpp
View file @
2243df56
...
...
@@ -19,8 +19,10 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
Negative
::
Negative
(
const
shared_ptr
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
"Negative"
,
arg
)
const
string
op
::
Negative
::
type_name
{
"Negative"
};
op
::
Negative
::
Negative
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
constructor_validate_and_infer_types
();
}
...
...
@@ -40,7 +42,7 @@ void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints
.
add_delta
(
x
,
-
delta
);
}
shared_ptr
<
Node
>
ngraph
::
operator
-
(
const
shared_ptr
<
Node
>
arg0
)
shared_ptr
<
Node
>
ngraph
::
operator
-
(
const
Output
<
Node
>&
arg0
)
{
return
make_shared
<
op
::
Negative
>
(
arg0
);
}
src/ngraph/op/negative.hpp
View file @
2243df56
...
...
@@ -26,17 +26,23 @@ namespace ngraph
class
Negative
:
public
util
::
UnaryElementwiseArithmetic
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a negative operation.
Negative
()
=
default
;
/// \brief Constructs a negative operation.
///
/// \param arg Node that produces the input tensor.
Negative
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Negative
(
const
Output
<
Node
>&
arg
);
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
};
}
std
::
shared_ptr
<
ngraph
::
Node
>
operator
-
(
const
std
::
shared_ptr
<
ngraph
::
Node
>
arg0
);
std
::
shared_ptr
<
Node
>
operator
-
(
const
Output
<
Node
>&
arg0
);
}
src/ngraph/op/not.cpp
View file @
2243df56
...
...
@@ -20,8 +20,10 @@
using
namespace
ngraph
;
using
namespace
std
;
op
::
Not
::
Not
(
const
shared_ptr
<
Node
>&
arg
)
:
Op
(
"Not"
,
check_single_output_args
({
arg
}))
const
string
op
::
Not
::
type_name
{
"Not"
};
op
::
Not
::
Not
(
const
Output
<
Node
>&
arg
)
:
Op
({
arg
})
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/not.hpp
View file @
2243df56
...
...
@@ -26,10 +26,15 @@ namespace ngraph
class
Not
:
public
Op
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a logical negation operation.
Not
()
=
default
;
/// \brief Constructs a logical negation operation.
///
/// \param arg Node that produces the input tensor.
Not
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Not
(
const
Output
<
Node
>&
arg
);
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/not_equal.cpp
View file @
2243df56
...
...
@@ -19,10 +19,12 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
NotEqual
::
NotEqual
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
string
op
::
NotEqual
::
type_name
{
"NotEqual"
};
op
::
NotEqual
::
NotEqual
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseComparison
(
"NotEqual"
,
arg0
,
arg1
,
autob
)
:
BinaryElementwiseComparison
(
arg0
,
arg1
,
autob
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/not_equal.hpp
View file @
2243df56
...
...
@@ -26,17 +26,24 @@ namespace ngraph
class
NotEqual
:
public
util
::
BinaryElementwiseComparison
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a not-equal operation.
NotEqual
()
=
default
;
/// \brief Constructs a not-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
NotEqual
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
NotEqual
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
bool
is_commutative
()
const
override
{
return
true
;
}
};
}
}
src/ngraph/op/one_hot.cpp
View file @
2243df56
...
...
@@ -20,8 +20,10 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
OneHot
::
OneHot
(
const
shared_ptr
<
Node
>&
arg
,
const
PartialShape
&
shape
,
size_t
one_hot_axis
)
:
Op
(
"OneHot"
,
check_single_output_args
({
arg
}))
const
string
op
::
OneHot
::
type_name
{
"OneHot"
};
op
::
OneHot
::
OneHot
(
const
Output
<
Node
>&
arg
,
const
PartialShape
&
shape
,
size_t
one_hot_axis
)
:
Op
({
arg
})
,
m_shape
(
shape
)
,
m_one_hot_axis
(
one_hot_axis
)
{
...
...
src/ngraph/op/one_hot.hpp
View file @
2243df56
...
...
@@ -45,14 +45,17 @@ namespace ngraph
class
OneHot
:
public
Op
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a one-hot operation.
OneHot
()
=
default
;
/// \brief Constructs a one-hot operation.
///
/// \param arg Node that produces the input tensor to be one-hot encoded.
/// \param shape The shape of the output tensor, including the new one-hot axis.
/// \param one_hot_axis The index within the output shape of the new one-hot axis.
OneHot
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
PartialShape
&
shape
,
size_t
one_hot_axis
);
OneHot
(
const
Output
<
Node
>&
arg
,
const
PartialShape
&
shape
,
size_t
one_hot_axis
);
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
@@ -60,6 +63,7 @@ namespace ngraph
/// \return The index of the one-hot axis.
size_t
get_one_hot_axis
()
const
{
return
m_one_hot_axis
;
}
void
set_one_hot_axis
(
size_t
one_hot_axis
)
{
m_one_hot_axis
=
one_hot_axis
;
}
protected
:
PartialShape
m_shape
;
size_t
m_one_hot_axis
;
...
...
src/ngraph/op/or.cpp
View file @
2243df56
...
...
@@ -19,10 +19,10 @@
using
namespace
std
;
using
namespace
ngraph
;
op
::
Or
::
Or
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseLogical
(
"Or"
,
arg0
,
arg1
,
autob
)
const
string
op
::
Or
::
type_name
{
"Or"
};
op
::
Or
::
Or
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseLogical
(
arg0
,
arg1
,
autob
)
{
constructor_validate_and_infer_types
();
}
...
...
src/ngraph/op/or.hpp
View file @
2243df56
...
...
@@ -29,6 +29,9 @@ namespace ngraph
class
Or
:
public
util
::
BinaryElementwiseLogical
{
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a logical-or operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
...
...
@@ -39,15 +42,14 @@ namespace ngraph
///
/// Output `[d0, ...]`
///
Or
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
Or
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
protected
:
virtual
bool
is_commutative
()
override
{
return
true
;
}
virtual
bool
is_commutative
()
const
override
{
return
true
;
}
};
}
}
src/ngraph/runtime/plaidml/CMakeLists.txt
View file @
2243df56
...
...
@@ -74,3 +74,11 @@ target_include_directories(plaidml_backend SYSTEM PUBLIC ${PLAIDML_INCLUDE_DIRS}
target_link_libraries
(
plaidml_backend PUBLIC ngraph libplaidml
)
install
(
TARGETS plaidml_backend LIBRARY DESTINATION
${
NGRAPH_INSTALL_LIB
}
)
set
(
CMAKE_MACOSX_RPATH 1
)
if
(
APPLE
)
set_property
(
TARGET plaidml_backend PROPERTY INSTALL_RPATH
"@loader_path/;@loader_path/../../.."
)
elseif
(
DEFINED NGRAPH_RPATH
)
set_property
(
TARGET plaidml_backend PROPERTY INSTALL_RPATH
"
\$
ORIGIN;
\$
ORIGIN/../../..;
${
NGRAPH_RPATH
}
"
)
else
()
set_property
(
TARGET plaidml_backend PROPERTY INSTALL_RPATH
"
\$
ORIGIN;
\$
ORIGIN/../../.."
)
endif
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment