Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
0552fd84
Unverified
Commit
0552fd84
authored
Jun 25, 2019
by
Scott Cyphers
Committed by
GitHub
Jun 25, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Convert to new op form (#3112)
parent
5e7aacf1
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
64 changed files
with
410 additions
and
252 deletions
+410
-252
node.cpp
src/ngraph/node.cpp
+10
-0
node.hpp
src/ngraph/node.hpp
+2
-0
abs.cpp
src/ngraph/op/abs.cpp
+0
-4
abs.hpp
src/ngraph/op/abs.hpp
+1
-1
acos.cpp
src/ngraph/op/acos.cpp
+0
-4
acos.hpp
src/ngraph/op/acos.hpp
+1
-1
add.cpp
src/ngraph/op/add.cpp
+0
-4
add.hpp
src/ngraph/op/add.hpp
+1
-1
all.cpp
src/ngraph/op/all.cpp
+0
-4
all.hpp
src/ngraph/op/all.hpp
+1
-1
and.cpp
src/ngraph/op/and.cpp
+0
-4
and.hpp
src/ngraph/op/and.hpp
+1
-1
any.cpp
src/ngraph/op/any.cpp
+0
-4
any.hpp
src/ngraph/op/any.hpp
+1
-1
argmax.cpp
src/ngraph/op/argmax.cpp
+0
-4
argmax.hpp
src/ngraph/op/argmax.hpp
+1
-1
argmin.cpp
src/ngraph/op/argmin.cpp
+0
-4
argmin.hpp
src/ngraph/op/argmin.hpp
+1
-1
asin.cpp
src/ngraph/op/asin.cpp
+0
-4
asin.hpp
src/ngraph/op/asin.hpp
+1
-1
atan.cpp
src/ngraph/op/atan.cpp
+0
-4
atan.hpp
src/ngraph/op/atan.hpp
+1
-1
avg_pool.cpp
src/ngraph/op/avg_pool.cpp
+0
-8
avg_pool.hpp
src/ngraph/op/avg_pool.hpp
+2
-2
batch_norm.cpp
src/ngraph/op/batch_norm.cpp
+42
-40
batch_norm.hpp
src/ngraph/op/batch_norm.hpp
+42
-28
broadcast.cpp
src/ngraph/op/broadcast.cpp
+11
-10
broadcast.hpp
src/ngraph/op/broadcast.hpp
+25
-10
broadcast_distributed.cpp
src/ngraph/op/broadcast_distributed.cpp
+9
-2
broadcast_distributed.hpp
src/ngraph/op/broadcast_distributed.hpp
+7
-2
ceiling.cpp
src/ngraph/op/ceiling.cpp
+4
-2
ceiling.hpp
src/ngraph/op/ceiling.hpp
+6
-1
concat.cpp
src/ngraph/op/concat.cpp
+9
-2
concat.hpp
src/ngraph/op/concat.hpp
+17
-1
constant.cpp
src/ngraph/op/constant.cpp
+2
-0
constant.hpp
src/ngraph/op/constant.hpp
+5
-2
convert.cpp
src/ngraph/op/convert.cpp
+4
-2
convert.hpp
src/ngraph/op/convert.hpp
+12
-2
convolution.cpp
src/ngraph/op/convolution.cpp
+27
-21
convolution.hpp
src/ngraph/op/convolution.hpp
+0
-0
cos.cpp
src/ngraph/op/cos.cpp
+4
-2
cos.hpp
src/ngraph/op/cos.hpp
+6
-1
cosh.cpp
src/ngraph/op/cosh.cpp
+4
-2
cosh.hpp
src/ngraph/op/cosh.hpp
+6
-1
dequantize.cpp
src/ngraph/op/dequantize.cpp
+6
-4
dequantize.hpp
src/ngraph/op/dequantize.hpp
+17
-8
divide.cpp
src/ngraph/op/divide.cpp
+9
-8
divide.hpp
src/ngraph/op/divide.hpp
+13
-7
dot.cpp
src/ngraph/op/dot.cpp
+7
-5
dot.hpp
src/ngraph/op/dot.hpp
+17
-3
embedding_lookup.cpp
src/ngraph/op/embedding_lookup.cpp
+2
-0
embedding_lookup.hpp
src/ngraph/op/embedding_lookup.hpp
+7
-2
equal.cpp
src/ngraph/op/equal.cpp
+4
-4
equal.hpp
src/ngraph/op/equal.hpp
+8
-3
erf.cpp
src/ngraph/op/erf.cpp
+4
-2
erf.hpp
src/ngraph/op/erf.hpp
+5
-1
exp.cpp
src/ngraph/op/exp.cpp
+4
-2
exp.hpp
src/ngraph/op/exp.hpp
+6
-1
floor.cpp
src/ngraph/op/floor.cpp
+4
-2
floor.hpp
src/ngraph/op/floor.hpp
+6
-1
reshape.cpp
src/ngraph/op/reshape.cpp
+4
-2
reshape.hpp
src/ngraph/op/reshape.hpp
+11
-3
result.cpp
src/ngraph/op/result.cpp
+4
-2
result.hpp
src/ngraph/op/result.hpp
+6
-1
No files found.
src/ngraph/node.cpp
View file @
0552fd84
...
@@ -559,6 +559,16 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
...
@@ -559,6 +559,16 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
return
args
;
return
args
;
}
}
OutputVector
ngraph
::
as_output_vector
(
const
NodeVector
&
args
)
{
OutputVector
output_vector
;
for
(
auto
&
arg
:
check_single_output_args
(
args
))
{
output_vector
.
push_back
(
arg
);
}
return
output_vector
;
}
std
::
tuple
<
element
::
Type
,
PartialShape
>
std
::
tuple
<
element
::
Type
,
PartialShape
>
Node
::
validate_and_infer_elementwise_args
(
const
op
::
AutoBroadcastSpec
&
autob
)
Node
::
validate_and_infer_elementwise_args
(
const
op
::
AutoBroadcastSpec
&
autob
)
{
{
...
...
src/ngraph/node.hpp
View file @
0552fd84
...
@@ -73,6 +73,8 @@ namespace ngraph
...
@@ -73,6 +73,8 @@ namespace ngraph
size_t
i
);
size_t
i
);
const
NodeVector
&
check_single_output_args
(
const
NodeVector
&
args
);
const
NodeVector
&
check_single_output_args
(
const
NodeVector
&
args
);
OutputVector
as_output_vector
(
const
NodeVector
&
args
);
/// Alias useful for cloning
/// Alias useful for cloning
using
NodeMap
=
std
::
unordered_map
<
ngraph
::
Node
*
,
std
::
shared_ptr
<
ngraph
::
Node
>>
;
using
NodeMap
=
std
::
unordered_map
<
ngraph
::
Node
*
,
std
::
shared_ptr
<
ngraph
::
Node
>>
;
...
...
src/ngraph/op/abs.cpp
View file @
0552fd84
...
@@ -23,10 +23,6 @@ using namespace ngraph;
...
@@ -23,10 +23,6 @@ using namespace ngraph;
const
string
op
::
Abs
::
type_name
{
"Abs"
};
const
string
op
::
Abs
::
type_name
{
"Abs"
};
op
::
Abs
::
Abs
()
{
}
op
::
Abs
::
Abs
(
const
Output
<
Node
>&
arg
)
op
::
Abs
::
Abs
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
...
...
src/ngraph/op/abs.hpp
View file @
0552fd84
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs an absolute value operation.
/// \brief Constructs an absolute value operation.
Abs
();
Abs
()
=
default
;
/// \brief Constructs an absolute value operation.
/// \brief Constructs an absolute value operation.
///
///
...
...
src/ngraph/op/acos.cpp
View file @
0552fd84
...
@@ -34,10 +34,6 @@ using namespace ngraph;
...
@@ -34,10 +34,6 @@ using namespace ngraph;
const
string
op
::
Acos
::
type_name
{
"Acos"
};
const
string
op
::
Acos
::
type_name
{
"Acos"
};
op
::
Acos
::
Acos
()
{
}
op
::
Acos
::
Acos
(
const
Output
<
Node
>&
arg
)
op
::
Acos
::
Acos
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
...
...
src/ngraph/op/acos.hpp
View file @
0552fd84
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs an arccos operation.
/// \brief Constructs an arccos operation.
Acos
();
Acos
()
=
default
;
/// \brief Constructs an arccos operation.
/// \brief Constructs an arccos operation.
///
///
/// \param arg Output that produces the input tensor.<br>
/// \param arg Output that produces the input tensor.<br>
...
...
src/ngraph/op/add.cpp
View file @
0552fd84
...
@@ -21,10 +21,6 @@ using namespace ngraph;
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const
string
op
::
Add
::
type_name
{
"Add"
};
const
string
op
::
Add
::
type_name
{
"Add"
};
op
::
Add
::
Add
()
{
}
op
::
Add
::
Add
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
op
::
Add
::
Add
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
,
autob
)
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
,
autob
)
{
{
...
...
src/ngraph/op/add.hpp
View file @
0552fd84
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs an unitialized addition operation
/// \brief Constructs an unitialized addition operation
Add
();
Add
()
=
default
;
/// \brief Constructs an addition operation.
/// \brief Constructs an addition operation.
///
///
...
...
src/ngraph/op/all.cpp
View file @
0552fd84
...
@@ -21,10 +21,6 @@ using namespace ngraph;
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const
string
op
::
All
::
type_name
{
"All"
};
const
string
op
::
All
::
type_name
{
"All"
};
op
::
All
::
All
()
{
}
op
::
All
::
All
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reduction_axes
)
op
::
All
::
All
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reduction_axes
)
:
LogicalReduction
(
arg
,
reduction_axes
)
:
LogicalReduction
(
arg
,
reduction_axes
)
{
{
...
...
src/ngraph/op/all.hpp
View file @
0552fd84
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs an "all" reduction operation.
/// \brief Constructs an "all" reduction operation.
All
();
All
()
=
default
;
/// \brief Constructs an "all" reduction operation.
/// \brief Constructs an "all" reduction operation.
///
///
/// \param arg The tensor to be reduced.
/// \param arg The tensor to be reduced.
...
...
src/ngraph/op/and.cpp
View file @
0552fd84
...
@@ -21,10 +21,6 @@ using namespace ngraph;
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const
string
op
::
And
::
type_name
{
"And"
};
const
string
op
::
And
::
type_name
{
"And"
};
op
::
And
::
And
()
{
}
op
::
And
::
And
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
op
::
And
::
And
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseLogical
(
arg0
,
arg1
,
autob
)
:
BinaryElementwiseLogical
(
arg0
,
arg1
,
autob
)
{
{
...
...
src/ngraph/op/and.hpp
View file @
0552fd84
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a logical-and operation.
/// \brief Constructs a logical-and operation.
And
();
And
()
=
default
;
/// \brief Constructs a logical-and operation.
/// \brief Constructs a logical-and operation.
///
///
...
...
src/ngraph/op/any.cpp
View file @
0552fd84
...
@@ -21,10 +21,6 @@ using namespace ngraph;
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const
string
op
::
Any
::
type_name
{
"Any"
};
const
string
op
::
Any
::
type_name
{
"Any"
};
op
::
Any
::
Any
()
{
}
op
::
Any
::
Any
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reduction_axes
)
op
::
Any
::
Any
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reduction_axes
)
:
LogicalReduction
(
arg
,
reduction_axes
)
:
LogicalReduction
(
arg
,
reduction_axes
)
{
{
...
...
src/ngraph/op/any.hpp
View file @
0552fd84
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs an "any" reduction operation.
/// \brief Constructs an "any" reduction operation.
Any
();
Any
()
=
default
;
/// \brief Constructs an "any" reduction operation.
/// \brief Constructs an "any" reduction operation.
///
///
/// \param arg The tensor to be reduced.
/// \param arg The tensor to be reduced.
...
...
src/ngraph/op/argmax.cpp
View file @
0552fd84
...
@@ -21,10 +21,6 @@ using namespace ngraph;
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const
string
op
::
ArgMax
::
type_name
{
"ArgMax"
};
const
string
op
::
ArgMax
::
type_name
{
"ArgMax"
};
op
::
ArgMax
::
ArgMax
()
{
}
op
::
ArgMax
::
ArgMax
(
const
Output
<
Node
>&
arg
,
size_t
axis
,
const
element
::
Type
&
index_element_type
)
op
::
ArgMax
::
ArgMax
(
const
Output
<
Node
>&
arg
,
size_t
axis
,
const
element
::
Type
&
index_element_type
)
:
op
::
util
::
IndexReduction
(
arg
,
axis
,
index_element_type
)
:
op
::
util
::
IndexReduction
(
arg
,
axis
,
index_element_type
)
{
{
...
...
src/ngraph/op/argmax.hpp
View file @
0552fd84
...
@@ -32,7 +32,7 @@ namespace ngraph
...
@@ -32,7 +32,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a ArgMax operation.
/// \brief Constructs a ArgMax operation.
ArgMax
();
ArgMax
()
=
default
;
/// \brief Constructs a ArgMax operation.
/// \brief Constructs a ArgMax operation.
///
///
/// \param arg The input tensor
/// \param arg The input tensor
...
...
src/ngraph/op/argmin.cpp
View file @
0552fd84
...
@@ -21,10 +21,6 @@ using namespace ngraph;
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const
string
op
::
ArgMin
::
type_name
{
"ArgMin"
};
const
string
op
::
ArgMin
::
type_name
{
"ArgMin"
};
op
::
ArgMin
::
ArgMin
()
{
}
op
::
ArgMin
::
ArgMin
(
const
Output
<
Node
>&
arg
,
size_t
axis
,
const
element
::
Type
&
index_element_type
)
op
::
ArgMin
::
ArgMin
(
const
Output
<
Node
>&
arg
,
size_t
axis
,
const
element
::
Type
&
index_element_type
)
:
op
::
util
::
IndexReduction
(
arg
,
axis
,
index_element_type
)
:
op
::
util
::
IndexReduction
(
arg
,
axis
,
index_element_type
)
{
{
...
...
src/ngraph/op/argmin.hpp
View file @
0552fd84
...
@@ -32,7 +32,7 @@ namespace ngraph
...
@@ -32,7 +32,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a ArgMin operation.
/// \brief Constructs a ArgMin operation.
ArgMin
();
ArgMin
()
=
default
;
/// \brief Constructs a ArgMin operation.
/// \brief Constructs a ArgMin operation.
///
///
...
...
src/ngraph/op/asin.cpp
View file @
0552fd84
...
@@ -33,10 +33,6 @@ using namespace ngraph;
...
@@ -33,10 +33,6 @@ using namespace ngraph;
const
string
op
::
Asin
::
type_name
{
"Asin"
};
const
string
op
::
Asin
::
type_name
{
"Asin"
};
op
::
Asin
::
Asin
()
{
}
op
::
Asin
::
Asin
(
const
Output
<
Node
>&
arg
)
op
::
Asin
::
Asin
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
...
...
src/ngraph/op/asin.hpp
View file @
0552fd84
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs an arcsin operation.
/// \brief Constructs an arcsin operation.
Asin
();
Asin
()
=
default
;
/// \brief Constructs an arcsin operation.
/// \brief Constructs an arcsin operation.
///
///
/// \param arg Output that produces the input tensor.<br>
/// \param arg Output that produces the input tensor.<br>
...
...
src/ngraph/op/atan.cpp
View file @
0552fd84
...
@@ -32,10 +32,6 @@ using namespace ngraph;
...
@@ -32,10 +32,6 @@ using namespace ngraph;
const
string
op
::
Atan
::
type_name
{
"Atan"
};
const
string
op
::
Atan
::
type_name
{
"Atan"
};
op
::
Atan
::
Atan
()
{
}
op
::
Atan
::
Atan
(
const
Output
<
Node
>&
arg
)
op
::
Atan
::
Atan
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
...
...
src/ngraph/op/atan.hpp
View file @
0552fd84
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs an arctan operation.
/// \brief Constructs an arctan operation.
Atan
();
Atan
()
=
default
;
/// \brief Constructs an arctan operation.
/// \brief Constructs an arctan operation.
///
///
...
...
src/ngraph/op/avg_pool.cpp
View file @
0552fd84
...
@@ -23,10 +23,6 @@ using namespace ngraph;
...
@@ -23,10 +23,6 @@ using namespace ngraph;
const
string
op
::
AvgPool
::
type_name
{
"AvgPool"
};
const
string
op
::
AvgPool
::
type_name
{
"AvgPool"
};
op
::
AvgPool
::
AvgPool
()
{
}
op
::
AvgPool
::
AvgPool
(
const
Output
<
Node
>&
arg
,
op
::
AvgPool
::
AvgPool
(
const
Output
<
Node
>&
arg
,
const
Shape
&
window_shape
,
const
Shape
&
window_shape
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_movement_strides
,
...
@@ -231,10 +227,6 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con
...
@@ -231,10 +227,6 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con
const
string
op
::
AvgPoolBackprop
::
type_name
(
"AvgPoolBackprop"
);
const
string
op
::
AvgPoolBackprop
::
type_name
(
"AvgPoolBackprop"
);
op
::
AvgPoolBackprop
::
AvgPoolBackprop
()
{
}
op
::
AvgPoolBackprop
::
AvgPoolBackprop
(
const
Shape
&
forward_arg_shape
,
op
::
AvgPoolBackprop
::
AvgPoolBackprop
(
const
Shape
&
forward_arg_shape
,
const
shared_ptr
<
Node
>&
delta
,
const
shared_ptr
<
Node
>&
delta
,
const
Shape
&
window_shape
,
const
Shape
&
window_shape
,
...
...
src/ngraph/op/avg_pool.hpp
View file @
0552fd84
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a batched average pooling operation.
/// \brief Constructs a batched average pooling operation.
AvgPool
();
AvgPool
()
=
default
;
/// \brief Constructs a batched average pooling operation.
/// \brief Constructs a batched average pooling operation.
///
///
...
@@ -175,7 +175,7 @@ namespace ngraph
...
@@ -175,7 +175,7 @@ namespace ngraph
public
:
public
:
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
AvgPoolBackprop
();
AvgPoolBackprop
()
=
default
;
AvgPoolBackprop
(
const
Shape
&
forward_arg_shape
,
AvgPoolBackprop
(
const
Shape
&
forward_arg_shape
,
const
std
::
shared_ptr
<
Node
>&
delta
,
const
std
::
shared_ptr
<
Node
>&
delta
,
const
Shape
&
window_shape
,
const
Shape
&
window_shape
,
...
...
src/ngraph/op/batch_norm.cpp
View file @
0552fd84
...
@@ -22,11 +22,13 @@
...
@@ -22,11 +22,13 @@
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp"
#include "ngraph/validation_util.hpp"
ngraph
::
op
::
BatchNormTraining
::
BatchNormTraining
(
std
::
shared_ptr
<
ngraph
::
Node
>
input
,
const
std
::
string
ngraph
::
op
::
BatchNormTraining
::
type_name
{
"BatchNormTraining"
};
std
::
shared_ptr
<
ngraph
::
Node
>
gamma
,
std
::
shared_ptr
<
ngraph
::
Node
>
beta
,
ngraph
::
op
::
BatchNormTraining
::
BatchNormTraining
(
Output
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
beta
,
double
epsilon
)
double
epsilon
)
:
Op
(
"BatchNormTraining"
,
check_single_output_args
({
gamma
,
beta
,
input
})
)
:
Op
(
{
gamma
,
beta
,
input
}
)
,
m_epsilon
(
epsilon
)
,
m_epsilon
(
epsilon
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -34,10 +36,10 @@ ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> i
...
@@ -34,10 +36,10 @@ ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> i
// DEPRECATED
// DEPRECATED
ngraph
::
op
::
BatchNormTraining
::
BatchNormTraining
(
double
eps
,
ngraph
::
op
::
BatchNormTraining
::
BatchNormTraining
(
double
eps
,
std
::
shared_ptr
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
gamma
,
std
::
shared_ptr
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
beta
,
std
::
shared_ptr
<
ngraph
::
Node
>
input
)
Output
<
ngraph
::
Node
>
input
)
:
Op
(
"BatchNormTraining"
,
check_single_output_args
({
gamma
,
beta
,
input
})
)
:
Op
(
{
gamma
,
beta
,
input
}
)
,
m_epsilon
(
eps
)
,
m_epsilon
(
eps
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -111,13 +113,15 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
...
@@ -111,13 +113,15 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
adjoints
.
add_delta
(
beta
,
dbeta
);
adjoints
.
add_delta
(
beta
,
dbeta
);
}
}
ngraph
::
op
::
BatchNormInference
::
BatchNormInference
(
std
::
shared_ptr
<
ngraph
::
Node
>
input
,
const
std
::
string
ngraph
::
op
::
BatchNormInference
::
type_name
{
"BatchNormInference"
};
std
::
shared_ptr
<
ngraph
::
Node
>
gamma
,
std
::
shared_ptr
<
ngraph
::
Node
>
beta
,
ngraph
::
op
::
BatchNormInference
::
BatchNormInference
(
Output
<
ngraph
::
Node
>
input
,
std
::
shared_ptr
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
gamma
,
std
::
shared_ptr
<
ngraph
::
Node
>
variance
,
Output
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
variance
,
double
epsilon
)
double
epsilon
)
:
Op
(
"BatchNormInference"
,
check_single_output_args
({
gamma
,
beta
,
input
,
mean
,
variance
})
)
:
Op
(
{
gamma
,
beta
,
input
,
mean
,
variance
}
)
,
m_epsilon
(
epsilon
)
,
m_epsilon
(
epsilon
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -125,12 +129,12 @@ ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node>
...
@@ -125,12 +129,12 @@ ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node>
// DEPRECATED
// DEPRECATED
ngraph
::
op
::
BatchNormInference
::
BatchNormInference
(
double
eps
,
ngraph
::
op
::
BatchNormInference
::
BatchNormInference
(
double
eps
,
std
::
shared_ptr
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
gamma
,
std
::
shared_ptr
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
beta
,
std
::
shared_ptr
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
input
,
std
::
shared_ptr
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
mean
,
std
::
shared_ptr
<
ngraph
::
Node
>
variance
)
Output
<
ngraph
::
Node
>
variance
)
:
Op
(
"BatchNormInference"
,
check_single_output_args
({
gamma
,
beta
,
input
,
mean
,
variance
})
)
:
Op
(
{
gamma
,
beta
,
input
,
mean
,
variance
}
)
,
m_epsilon
(
eps
)
,
m_epsilon
(
eps
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -167,16 +171,16 @@ std::shared_ptr<ngraph::Node>
...
@@ -167,16 +171,16 @@ std::shared_ptr<ngraph::Node>
new_args
.
at
(
2
),
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
3
),
new_args
.
at
(
4
),
m_epsilon
);
new_args
.
at
(
2
),
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
3
),
new_args
.
at
(
4
),
m_epsilon
);
}
}
ngraph
::
op
::
BatchNormTrainingBackprop
::
BatchNormTrainingBackprop
(
const
std
::
string
ngraph
::
op
::
BatchNormTrainingBackprop
::
type_name
{
"BatchNormTrainingBackprop"
};
std
::
shared_ptr
<
ngraph
::
Node
>
input
,
std
::
shared_ptr
<
ngraph
::
Node
>
gamma
,
ngraph
::
op
::
BatchNormTrainingBackprop
::
BatchNormTrainingBackprop
(
Output
<
ngraph
::
Node
>
input
,
std
::
shared_ptr
<
ngraph
::
Node
>
bet
a
,
Output
<
ngraph
::
Node
>
gamm
a
,
std
::
shared_ptr
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
beta
,
std
::
shared_ptr
<
ngraph
::
Node
>
variance
,
Output
<
ngraph
::
Node
>
mean
,
std
::
shared_ptr
<
ngraph
::
Node
>
delta
,
Output
<
ngraph
::
Node
>
variance
,
double
epsilon
)
Output
<
ngraph
::
Node
>
delta
,
:
Op
(
"BatchNormTrainingBackprop"
,
double
epsilon
)
check_single_output_args
({
gamma
,
beta
,
input
,
mean
,
variance
,
delta
})
)
:
Op
({
gamma
,
beta
,
input
,
mean
,
variance
,
delta
}
)
,
m_epsilon
(
epsilon
)
,
m_epsilon
(
epsilon
)
{
{
...
@@ -184,16 +188,14 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
...
@@ -184,16 +188,14 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
ngraph
::
op
::
BatchNormTrainingBackprop
::
BatchNormTrainingBackprop
(
ngraph
::
op
::
BatchNormTrainingBackprop
::
BatchNormTrainingBackprop
(
double
epsilon
,
double
epsilon
,
Output
<
ngraph
::
Node
>
gamma
,
std
::
shared_ptr
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
beta
,
std
::
shared_ptr
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
input
,
std
::
shared_ptr
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
mean
,
std
::
shared_ptr
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
variance
,
std
::
shared_ptr
<
ngraph
::
Node
>
variance
,
Output
<
ngraph
::
Node
>
delta
)
std
::
shared_ptr
<
ngraph
::
Node
>
delta
)
:
Op
({
gamma
,
beta
,
input
,
mean
,
variance
,
delta
})
:
Op
(
"BatchNormTrainingBackprop"
,
check_single_output_args
({
gamma
,
beta
,
input
,
mean
,
variance
,
delta
}))
,
m_epsilon
(
epsilon
)
,
m_epsilon
(
epsilon
)
{
{
...
...
src/ngraph/op/batch_norm.hpp
View file @
0552fd84
...
@@ -31,13 +31,17 @@ namespace ngraph
...
@@ -31,13 +31,17 @@ namespace ngraph
class
BatchNormTraining
:
public
Op
class
BatchNormTraining
:
public
Op
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
BatchNormTraining
()
=
default
;
/// \param input Must have rank >= 2, [., C, ...]
/// \param input Must have rank >= 2, [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormTraining
(
std
::
shared_ptr
<
Node
>
input
,
BatchNormTraining
(
Output
<
Node
>
input
,
std
::
shared_ptr
<
Node
>
gamma
,
Output
<
Node
>
gamma
,
std
::
shared_ptr
<
Node
>
beta
,
Output
<
Node
>
beta
,
double
epsilon
);
double
epsilon
);
NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED_DOC
...
@@ -62,13 +66,14 @@ namespace ngraph
...
@@ -62,13 +66,14 @@ namespace ngraph
/// output[2]: shall have rank 1, with the same span as input's channel axis.
/// output[2]: shall have rank 1, with the same span as input's channel axis.
NGRAPH_DEPRECATED
(
"Use another constructor"
)
NGRAPH_DEPRECATED
(
"Use another constructor"
)
BatchNormTraining
(
double
eps
,
BatchNormTraining
(
double
eps
,
std
::
shared_ptr
<
Node
>
gamma
,
Output
<
Node
>
gamma
,
std
::
shared_ptr
<
Node
>
beta
,
Output
<
Node
>
beta
,
std
::
shared_ptr
<
Node
>
input
);
Output
<
Node
>
input
);
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
double
get_eps_value
()
const
{
return
m_epsilon
;
}
double
get_eps_value
()
const
{
return
m_epsilon
;
}
void
set_eps_value
(
double
epsilon
)
{
m_epsilon
=
epsilon
;
}
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
@@ -87,17 +92,20 @@ namespace ngraph
...
@@ -87,17 +92,20 @@ namespace ngraph
class
BatchNormInference
:
public
Op
class
BatchNormInference
:
public
Op
{
{
public
:
public
:
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
BatchNormInference
()
=
default
;
/// \param input [., C, ...]
/// \param input [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param mean value for mean normalization [C]
/// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C]
/// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference
(
std
::
shared_ptr
<
ngraph
::
Node
>
input
,
BatchNormInference
(
Output
<
ngraph
::
Node
>
input
,
std
::
shared_ptr
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
gamma
,
std
::
shared_ptr
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
beta
,
std
::
shared_ptr
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
mean
,
std
::
shared_ptr
<
ngraph
::
Node
>
variance
,
Output
<
ngraph
::
Node
>
variance
,
double
epsilon
);
double
epsilon
);
NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED_DOC
...
@@ -120,15 +128,16 @@ namespace ngraph
...
@@ -120,15 +128,16 @@ namespace ngraph
/// output: shall have the same shape as 'input'.
/// output: shall have the same shape as 'input'.
NGRAPH_DEPRECATED
(
"Use another constructor"
)
NGRAPH_DEPRECATED
(
"Use another constructor"
)
BatchNormInference
(
double
eps
,
BatchNormInference
(
double
eps
,
std
::
shared_ptr
<
ngraph
::
Node
>
gamma
,
Output
<
ngraph
::
Node
>
gamma
,
std
::
shared_ptr
<
ngraph
::
Node
>
beta
,
Output
<
ngraph
::
Node
>
beta
,
std
::
shared_ptr
<
ngraph
::
Node
>
input
,
Output
<
ngraph
::
Node
>
input
,
std
::
shared_ptr
<
ngraph
::
Node
>
mean
,
Output
<
ngraph
::
Node
>
mean
,
std
::
shared_ptr
<
ngraph
::
Node
>
variance
);
Output
<
ngraph
::
Node
>
variance
);
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
double
get_eps_value
()
const
{
return
m_epsilon
;
}
double
get_eps_value
()
const
{
return
m_epsilon
;
}
void
set_eps_value
(
double
epsilon
)
{
m_epsilon
=
epsilon
;
}
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
@@ -152,28 +161,33 @@ namespace ngraph
...
@@ -152,28 +161,33 @@ namespace ngraph
class
BatchNormTrainingBackprop
:
public
Op
class
BatchNormTrainingBackprop
:
public
Op
{
{
public
:
public
:
BatchNormTrainingBackprop
(
std
::
shared_ptr
<
Node
>
input
,
NGRAPH_API
std
::
shared_ptr
<
Node
>
gamma
,
static
const
std
::
string
type_name
;
std
::
shared_ptr
<
Node
>
beta
,
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
std
::
shared_ptr
<
Node
>
mean
,
BatchNormTrainingBackprop
()
=
default
;
std
::
shared_ptr
<
Node
>
variance
,
BatchNormTrainingBackprop
(
Output
<
Node
>
input
,
std
::
shared_ptr
<
Node
>
delta
,
Output
<
Node
>
gamma
,
Output
<
Node
>
beta
,
Output
<
Node
>
mean
,
Output
<
Node
>
variance
,
Output
<
Node
>
delta
,
double
epsilon
);
double
epsilon
);
NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED
(
"Use another constructor"
)
NGRAPH_DEPRECATED
(
"Use another constructor"
)
BatchNormTrainingBackprop
(
double
epsilon
,
BatchNormTrainingBackprop
(
double
epsilon
,
std
::
shared_ptr
<
Node
>
gamma
,
Output
<
Node
>
gamma
,
std
::
shared_ptr
<
Node
>
beta
,
Output
<
Node
>
beta
,
std
::
shared_ptr
<
Node
>
input
,
Output
<
Node
>
input
,
std
::
shared_ptr
<
Node
>
mean
,
Output
<
Node
>
mean
,
std
::
shared_ptr
<
Node
>
variance
,
Output
<
Node
>
variance
,
std
::
shared_ptr
<
Node
>
delta
);
Output
<
Node
>
delta
);
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
double
get_eps_value
()
const
{
return
m_epsilon
;
}
double
get_eps_value
()
const
{
return
m_epsilon
;
}
void
set_eps_value
(
double
epsilon
)
{
m_epsilon
=
epsilon
;
}
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/broadcast.cpp
View file @
0552fd84
...
@@ -20,21 +20,20 @@
...
@@ -20,21 +20,20 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Broadcast
::
Broadcast
(
const
std
::
string
&
name
,
const
string
op
::
Broadcast
::
type_name
{
"Broadcast"
};
const
NodeVector
&
args
,
op
::
Broadcast
::
Broadcast
(
const
OutputVector
&
args
,
const
Shape
&
shape
,
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
)
const
AxisSet
&
broadcast_axes
)
:
Op
(
name
,
check_single_output_args
(
args
)
)
:
Op
(
args
)
,
m_shape
(
shape
)
,
m_shape
(
shape
)
,
m_broadcast_axes
(
broadcast_axes
)
,
m_broadcast_axes
(
broadcast_axes
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
op
::
Broadcast
::
Broadcast
(
const
shared_ptr
<
Node
>&
arg
,
op
::
Broadcast
::
Broadcast
(
const
Output
<
Node
>&
arg
,
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
)
const
Shape
&
shape
,
:
Broadcast
(
OutputVector
{
arg
},
shape
,
broadcast_axes
)
const
AxisSet
&
broadcast_axes
)
:
Broadcast
(
"Broadcast"
,
{
arg
},
shape
,
broadcast_axes
)
{
{
}
}
...
@@ -96,10 +95,12 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
...
@@ -96,10 +95,12 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
adjoints
.
add_delta
(
x
,
make_shared
<
op
::
Sum
>
(
delta
,
m_broadcast_axes
));
adjoints
.
add_delta
(
x
,
make_shared
<
op
::
Sum
>
(
delta
,
m_broadcast_axes
));
}
}
op
::
BroadcastLike
::
BroadcastLike
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
string
op
::
BroadcastLike
::
type_name
{
"BroadcastLike"
};
const
std
::
shared_ptr
<
Node
>&
like_arg
,
op
::
BroadcastLike
::
BroadcastLike
(
const
Output
<
Node
>&
arg
,
const
Output
<
Node
>&
like_arg
,
const
AxisSet
&
initial_broadcast_axes
)
const
AxisSet
&
initial_broadcast_axes
)
:
Broadcast
(
"BroadcastLike"
,
{
arg
,
like_arg
},
{},
{})
:
Broadcast
({
arg
,
like_arg
},
{},
{})
,
m_initial_broadcast_axes
(
initial_broadcast_axes
)
,
m_initial_broadcast_axes
(
initial_broadcast_axes
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
...
src/ngraph/op/broadcast.hpp
View file @
0552fd84
...
@@ -27,15 +27,18 @@ namespace ngraph
...
@@ -27,15 +27,18 @@ namespace ngraph
class
Broadcast
:
public
Op
class
Broadcast
:
public
Op
{
{
public
:
public
:
/// \brief Constructs a conversion operation.
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a broadcast operation.
Broadcast
()
=
default
;
/// \brief Constructs a broadcast operation.
///
///
/// \param arg Node that produces the input tensor to be broadcast.
/// \param arg Node that produces the input tensor to be broadcast.
/// \param shape The shape of the output tensor.
/// \param shape The shape of the output tensor.
/// \param broadcast_axes The axis positions (0-based) in the result that are being broadcast. The
/// \param broadcast_axes The axis positions (0-based) in the result that are being broadcast. The
/// remaining axes in shape must be the same as the shape of arg.
/// remaining axes in shape must be the same as the shape of arg.
Broadcast
(
const
std
::
shared_ptr
<
Node
>&
arg
,
Broadcast
(
const
Output
<
Node
>&
arg
,
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
);
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
);
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
...
@@ -44,12 +47,14 @@ namespace ngraph
...
@@ -44,12 +47,14 @@ namespace ngraph
/// \return A set containing the indices of the broadcast axes (0-based).
/// \return A set containing the indices of the broadcast axes (0-based).
const
AxisSet
&
get_broadcast_axes
()
const
{
return
m_broadcast_axes
;
}
const
AxisSet
&
get_broadcast_axes
()
const
{
return
m_broadcast_axes
;
}
void
set_broadcast_axes
(
const
AxisSet
&
broadcast_axes
)
{
m_broadcast_axes
=
broadcast_axes
;
}
const
Shape
&
get_broadcast_shape
()
const
{
return
m_shape
;
}
const
Shape
&
get_broadcast_shape
()
const
{
return
m_shape
;
}
void
set_broadcast_shape
(
const
Shape
&
shape
)
{
m_shape
=
shape
;
}
protected
:
protected
:
Broadcast
(
const
std
::
string
&
node_type
,
Broadcast
(
const
OutputVector
&
args
,
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
);
const
NodeVector
&
args
,
const
Shape
&
shape
,
const
AxisSet
&
broadcast_axes
);
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
...
@@ -63,6 +68,11 @@ namespace ngraph
...
@@ -63,6 +68,11 @@ namespace ngraph
class
BroadcastLike
:
public
Broadcast
class
BroadcastLike
:
public
Broadcast
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Broadcast arg to the same shape as like_arg.
BroadcastLike
()
=
default
;
/// \brief Broadcast arg to the same shape as like_arg.
/// \brief Broadcast arg to the same shape as like_arg.
///
///
/// Once the shape of like_arg is known, this op will be replaced with an equivalent
/// Once the shape of like_arg is known, this op will be replaced with an equivalent
...
@@ -72,8 +82,8 @@ namespace ngraph
...
@@ -72,8 +82,8 @@ namespace ngraph
/// \param like_arg Provides the shape for the result.
/// \param like_arg Provides the shape for the result.
/// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
/// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
/// arg must be scalar and all axes are broadcast.
/// arg must be scalar and all axes are broadcast.
BroadcastLike
(
const
std
::
shared_ptr
<
Node
>&
arg
,
BroadcastLike
(
const
Output
<
Node
>&
arg
,
const
std
::
shared_ptr
<
Node
>&
like_arg
,
const
Output
<
Node
>&
like_arg
,
const
AxisSet
&
initial_broadcast_axes
);
const
AxisSet
&
initial_broadcast_axes
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
...
@@ -81,6 +91,11 @@ namespace ngraph
...
@@ -81,6 +91,11 @@ namespace ngraph
void
infer_shape
()
override
;
void
infer_shape
()
override
;
const
AxisSet
&
get_initial_broadcast_axes
()
const
{
return
m_initial_broadcast_axes
;
}
const
AxisSet
&
get_initial_broadcast_axes
()
const
{
return
m_initial_broadcast_axes
;
}
void
set_initial_broadcast_axes
(
const
AxisSet
&
initial_broadcast_axes
)
{
m_initial_broadcast_axes
=
initial_broadcast_axes
;
}
protected
:
protected
:
AxisSet
m_initial_broadcast_axes
;
AxisSet
m_initial_broadcast_axes
;
};
};
...
...
src/ngraph/op/broadcast_distributed.cpp
View file @
0552fd84
...
@@ -19,8 +19,10 @@
...
@@ -19,8 +19,10 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
BroadcastDistributed
::
BroadcastDistributed
(
const
shared_ptr
<
Node
>&
arg
,
int
root_id
)
const
string
op
::
BroadcastDistributed
::
type_name
{
"BroadcastDistributed"
};
:
Op
(
"BroadcastDistributed"
,
check_single_output_args
({
arg
}))
op
::
BroadcastDistributed
::
BroadcastDistributed
(
const
Output
<
Node
>&
arg
,
int
root_id
)
:
Op
({
arg
})
,
m_root_id
(
root_id
)
,
m_root_id
(
root_id
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -49,3 +51,8 @@ int op::BroadcastDistributed::get_root_id() const
...
@@ -49,3 +51,8 @@ int op::BroadcastDistributed::get_root_id() const
{
{
return
m_root_id
;
return
m_root_id
;
}
}
void
op
::
BroadcastDistributed
::
set_root_id
(
int
root_id
)
{
m_root_id
=
root_id
;
}
src/ngraph/op/broadcast_distributed.hpp
View file @
0552fd84
...
@@ -27,16 +27,21 @@ namespace ngraph
...
@@ -27,16 +27,21 @@ namespace ngraph
class
BroadcastDistributed
:
public
Op
class
BroadcastDistributed
:
public
Op
{
{
public
:
public
:
BroadcastDistributed
(
const
std
::
shared_ptr
<
Node
>&
arg
,
int
root_id
=
0
);
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
BroadcastDistributed
()
=
default
;
BroadcastDistributed
(
const
Output
<
Node
>&
arg
,
int
root_id
=
0
);
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
int
get_root_id
()
const
;
int
get_root_id
()
const
;
void
set_root_id
(
int
root_id
);
private
:
private
:
const
int
m_root_id
;
int
m_root_id
;
};
};
}
}
}
}
src/ngraph/op/ceiling.cpp
View file @
0552fd84
...
@@ -19,8 +19,10 @@
...
@@ -19,8 +19,10 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Ceiling
::
Ceiling
(
const
shared_ptr
<
Node
>&
arg
)
const
string
op
::
Ceiling
::
type_name
{
"Ceiling"
};
:
UnaryElementwiseArithmetic
(
"Ceiling"
,
arg
)
op
::
Ceiling
::
Ceiling
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
...
src/ngraph/op/ceiling.hpp
View file @
0552fd84
...
@@ -26,10 +26,15 @@ namespace ngraph
...
@@ -26,10 +26,15 @@ namespace ngraph
class
Ceiling
:
public
util
::
UnaryElementwiseArithmetic
class
Ceiling
:
public
util
::
UnaryElementwiseArithmetic
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a ceiling operation.
Ceiling
()
=
default
;
/// \brief Constructs a ceiling operation.
/// \brief Constructs a ceiling operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Ceiling
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Ceiling
(
const
Output
<
Node
>&
arg
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/concat.cpp
View file @
0552fd84
...
@@ -22,13 +22,20 @@
...
@@ -22,13 +22,20 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Concat
::
Concat
(
const
NodeVector
&
args
,
size_t
concatenation_axis
)
const
string
op
::
Concat
::
type_name
{
"Concat"
};
:
Op
(
"Concat"
,
check_single_output_args
(
args
))
op
::
Concat
::
Concat
(
const
OutputVector
&
args
,
size_t
concatenation_axis
)
:
Op
(
args
)
,
m_concatenation_axis
(
concatenation_axis
)
,
m_concatenation_axis
(
concatenation_axis
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
op
::
Concat
::
Concat
(
const
NodeVector
&
args
,
size_t
concatenation_axis
)
:
Concat
(
as_output_vector
(
args
),
concatenation_axis
)
{
}
void
op
::
Concat
::
validate_and_infer_types
()
void
op
::
Concat
::
validate_and_infer_types
()
{
{
NODE_VALIDATION_CHECK
(
this
,
get_input_size
()
>=
1
,
"At least one argument required."
);
NODE_VALIDATION_CHECK
(
this
,
get_input_size
()
>=
1
,
"At least one argument required."
);
...
...
src/ngraph/op/concat.hpp
View file @
0552fd84
...
@@ -28,6 +28,17 @@ namespace ngraph
...
@@ -28,6 +28,17 @@ namespace ngraph
class
Concat
:
public
Op
class
Concat
:
public
Op
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a concatenation operation.
Concat
()
=
default
;
/// \brief Constructs a concatenation operation.
///
/// \param args The outputs producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
Concat
(
const
OutputVector
&
args
,
size_t
concatenation_axis
);
/// \brief Constructs a concatenation operation.
/// \brief Constructs a concatenation operation.
///
///
/// \param args The nodes producing the input tensors.
/// \param args The nodes producing the input tensors.
...
@@ -41,10 +52,15 @@ namespace ngraph
...
@@ -41,10 +52,15 @@ namespace ngraph
/// \return The concatenation axis.
/// \return The concatenation axis.
size_t
get_concatenation_axis
()
const
{
return
m_concatenation_axis
;
}
size_t
get_concatenation_axis
()
const
{
return
m_concatenation_axis
;
}
void
set_concatenation_axis
(
size_t
concatenation_axis
)
{
m_concatenation_axis
=
concatenation_axis
;
}
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
const
size_t
m_concatenation_axis
;
size_t
m_concatenation_axis
;
};
};
}
}
}
}
src/ngraph/op/constant.cpp
View file @
0552fd84
...
@@ -45,6 +45,8 @@ string to_cpp_string(T value)
...
@@ -45,6 +45,8 @@ string to_cpp_string(T value)
return
rc
;
return
rc
;
}
}
const
string
op
::
Constant
::
type_name
{
"Constant"
};
op
::
Constant
::~
Constant
()
op
::
Constant
::~
Constant
()
{
{
}
}
...
...
src/ngraph/op/constant.hpp
View file @
0552fd84
...
@@ -34,6 +34,9 @@ namespace ngraph
...
@@ -34,6 +34,9 @@ namespace ngraph
class
Constant
:
public
Node
class
Constant
:
public
Node
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a tensor constant.
/// \brief Constructs a tensor constant.
///
///
/// \param type The element type of the tensor constant.
/// \param type The element type of the tensor constant.
...
@@ -78,7 +81,7 @@ namespace ngraph
...
@@ -78,7 +81,7 @@ namespace ngraph
/// \param shape The shape of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param values A list of string values to use as the constant data.
/// \param values A list of string values to use as the constant data.
Constant
(
const
element
::
Type
&
type
,
Shape
shape
,
const
std
::
vector
<
std
::
string
>&
values
)
Constant
(
const
element
::
Type
&
type
,
Shape
shape
,
const
std
::
vector
<
std
::
string
>&
values
)
:
Node
(
"Constant"
,
{})
:
Node
({})
,
m_element_type
(
type
)
,
m_element_type
(
type
)
,
m_shape
(
shape
)
,
m_shape
(
shape
)
,
m_data
(
new
runtime
::
AlignedBuffer
(
shape_size
(
m_shape
)
*
m_element_type
.
size
(),
,
m_data
(
new
runtime
::
AlignedBuffer
(
shape_size
(
m_shape
)
*
m_element_type
.
size
(),
...
@@ -135,7 +138,7 @@ namespace ngraph
...
@@ -135,7 +138,7 @@ namespace ngraph
/// \param shape The shape of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param data A void* to constant data.
/// \param data A void* to constant data.
Constant
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
const
void
*
data
)
Constant
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
const
void
*
data
)
:
Node
(
"Constant"
,
{})
:
Node
({})
,
m_element_type
(
type
)
,
m_element_type
(
type
)
,
m_shape
(
shape
)
,
m_shape
(
shape
)
,
m_data
(
nullptr
)
,
m_data
(
nullptr
)
...
...
src/ngraph/op/convert.cpp
View file @
0552fd84
...
@@ -21,8 +21,10 @@
...
@@ -21,8 +21,10 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Convert
::
Convert
(
const
shared_ptr
<
Node
>&
arg
,
const
element
::
Type
&
element_type
)
const
string
op
::
Convert
::
type_name
{
"Convert"
};
:
Op
(
"Convert"
,
check_single_output_args
({
arg
}))
op
::
Convert
::
Convert
(
const
Output
<
Node
>&
arg
,
const
element
::
Type
&
element_type
)
:
Op
({
arg
})
,
m_element_type
(
element_type
)
,
m_element_type
(
element_type
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
...
src/ngraph/op/convert.hpp
View file @
0552fd84
...
@@ -26,11 +26,16 @@ namespace ngraph
...
@@ -26,11 +26,16 @@ namespace ngraph
class
Convert
:
public
Op
class
Convert
:
public
Op
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a conversion operation.
Convert
()
=
default
;
/// \brief Constructs a conversion operation.
/// \brief Constructs a conversion operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
/// \param element_type Element type for the output tensor.
/// \param element_type Element type for the output tensor.
Convert
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
ngraph
::
element
::
Type
&
element_type
);
Convert
(
const
Output
<
Node
>&
arg
,
const
ngraph
::
element
::
Type
&
element_type
);
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
...
@@ -38,8 +43,13 @@ namespace ngraph
...
@@ -38,8 +43,13 @@ namespace ngraph
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
const
element
::
Type
&
get_convert_element_type
()
const
{
return
m_element_type
;
}
const
element
::
Type
&
get_convert_element_type
()
const
{
return
m_element_type
;
}
void
set_convert_element_type
(
const
element
::
Type
&
element_type
)
{
m_element_type
=
element_type
;
}
protected
:
protected
:
const
ngraph
::
element
::
Type
m_element_type
;
ngraph
::
element
::
Type
m_element_type
;
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
};
};
...
...
src/ngraph/op/convolution.cpp
View file @
0552fd84
...
@@ -27,15 +27,17 @@
...
@@ -27,15 +27,17 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Convolution
::
Convolution
(
const
shared_ptr
<
Node
>&
data_batch
,
const
string
op
::
Convolution
::
type_name
{
"Convolution"
};
const
shared_ptr
<
Node
>&
filters
,
op
::
Convolution
::
Convolution
(
const
Output
<
Node
>&
data_batch
,
const
Output
<
Node
>&
filters
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_dilation_strides
,
const
Strides
&
window_dilation_strides
,
const
CoordinateDiff
&
padding_below
,
const
CoordinateDiff
&
padding_below
,
const
CoordinateDiff
&
padding_above
,
const
CoordinateDiff
&
padding_above
,
const
Strides
&
data_dilation_strides
,
const
Strides
&
data_dilation_strides
,
const
PadType
&
pad_type
)
const
PadType
&
pad_type
)
:
Op
(
"Convolution"
,
check_single_output_args
({
data_batch
,
filters
})
)
:
Op
(
{
data_batch
,
filters
}
)
,
m_window_movement_strides
(
window_movement_strides
)
,
m_window_movement_strides
(
window_movement_strides
)
,
m_window_dilation_strides
(
window_dilation_strides
)
,
m_window_dilation_strides
(
window_dilation_strides
)
,
m_padding_below
(
padding_below
)
,
m_padding_below
(
padding_below
)
...
@@ -114,8 +116,8 @@ void op::Convolution::validate_and_infer_types()
...
@@ -114,8 +116,8 @@ void op::Convolution::validate_and_infer_types()
set_output_type
(
0
,
result_et
,
result_shape
);
set_output_type
(
0
,
result_et
,
result_shape
);
}
}
op
::
Convolution
::
Convolution
(
const
shared_ptr
<
Node
>&
data_batch
,
op
::
Convolution
::
Convolution
(
const
Output
<
Node
>&
data_batch
,
const
shared_ptr
<
Node
>&
filters
,
const
Output
<
Node
>&
filters
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_dilation_strides
,
const
Strides
&
window_dilation_strides
,
const
CoordinateDiff
&
padding_below
,
const
CoordinateDiff
&
padding_below
,
...
@@ -130,8 +132,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
...
@@ -130,8 +132,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
{
}
}
op
::
Convolution
::
Convolution
(
const
shared_ptr
<
Node
>&
data_batch
,
op
::
Convolution
::
Convolution
(
const
Output
<
Node
>&
data_batch
,
const
shared_ptr
<
Node
>&
filters
,
const
Output
<
Node
>&
filters
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_dilation_strides
)
const
Strides
&
window_dilation_strides
)
:
Convolution
(
data_batch
,
:
Convolution
(
data_batch
,
...
@@ -143,8 +145,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
...
@@ -143,8 +145,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
{
}
}
op
::
Convolution
::
Convolution
(
const
shared_ptr
<
Node
>&
data_batch
,
op
::
Convolution
::
Convolution
(
const
Output
<
Node
>&
data_batch
,
const
shared_ptr
<
Node
>&
filters
,
const
Output
<
Node
>&
filters
,
const
Strides
&
window_movement_strides
)
const
Strides
&
window_movement_strides
)
:
Convolution
(
data_batch
,
:
Convolution
(
data_batch
,
filters
,
filters
,
...
@@ -155,7 +157,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
...
@@ -155,7 +157,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
{
}
}
op
::
Convolution
::
Convolution
(
const
shared_ptr
<
Node
>&
data_batch
,
const
shared_ptr
<
Node
>&
filters
)
op
::
Convolution
::
Convolution
(
const
Output
<
Node
>&
data_batch
,
const
Output
<
Node
>&
filters
)
:
Convolution
(
data_batch
,
filters
,
Strides
(),
Strides
(),
CoordinateDiff
(),
CoordinateDiff
())
:
Convolution
(
data_batch
,
filters
,
Strides
(),
Strides
(),
CoordinateDiff
(),
CoordinateDiff
())
{
{
}
}
...
@@ -204,15 +206,17 @@ void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const Node
...
@@ -204,15 +206,17 @@ void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const Node
m_data_dilation_strides
));
m_data_dilation_strides
));
}
}
const
string
op
::
ConvolutionBackpropData
::
type_name
{
"ConvolutionBackpropData"
};
op
::
ConvolutionBackpropData
::
ConvolutionBackpropData
(
const
Shape
&
data_batch_shape
,
op
::
ConvolutionBackpropData
::
ConvolutionBackpropData
(
const
Shape
&
data_batch_shape
,
const
shared_ptr
<
Node
>&
filters
,
const
Output
<
Node
>&
filters
,
const
shared_ptr
<
Node
>&
output_delta
,
const
Output
<
Node
>&
output_delta
,
const
Strides
&
window_movement_strides_forward
,
const
Strides
&
window_movement_strides_forward
,
const
Strides
&
window_dilation_strides_forward
,
const
Strides
&
window_dilation_strides_forward
,
const
CoordinateDiff
&
padding_below_forward
,
const
CoordinateDiff
&
padding_below_forward
,
const
CoordinateDiff
&
padding_above_forward
,
const
CoordinateDiff
&
padding_above_forward
,
const
Strides
&
data_dilation_strides_forward
)
const
Strides
&
data_dilation_strides_forward
)
:
Op
(
"ConvolutionBackpropData"
,
check_single_output_args
({
filters
,
output_delta
})
)
:
Op
(
{
filters
,
output_delta
}
)
,
m_data_batch_shape
(
data_batch_shape
)
,
m_data_batch_shape
(
data_batch_shape
)
,
m_window_movement_strides_forward
(
window_movement_strides_forward
)
,
m_window_movement_strides_forward
(
window_movement_strides_forward
)
,
m_window_dilation_strides_forward
(
window_dilation_strides_forward
)
,
m_window_dilation_strides_forward
(
window_dilation_strides_forward
)
...
@@ -332,14 +336,14 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
...
@@ -332,14 +336,14 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
m_data_dilation_strides_forward
[
i
]);
m_data_dilation_strides_forward
[
i
]);
}
}
auto
swap_NC
=
[](
const
shared_ptr
<
Node
>
n
)
{
auto
swap_NC
=
[](
const
Output
<
Node
>&
n
)
{
AxisVector
ax_order
=
ngraph
::
get_default_order
(
n
->
get_shape
());
AxisVector
ax_order
=
ngraph
::
get_default_order
(
n
.
get_shape
());
ax_order
[
0
]
=
1
;
ax_order
[
0
]
=
1
;
ax_order
[
1
]
=
0
;
ax_order
[
1
]
=
0
;
auto
new_shape
=
n
->
get_shape
();
auto
new_shape
=
n
.
get_shape
();
new_shape
[
0
]
=
n
->
get_shape
()[
1
];
new_shape
[
0
]
=
n
.
get_shape
()[
1
];
new_shape
[
1
]
=
n
->
get_shape
()[
0
];
new_shape
[
1
]
=
n
.
get_shape
()[
0
];
return
make_shared
<
op
::
Reshape
>
(
n
,
ax_order
,
new_shape
);
return
make_shared
<
op
::
Reshape
>
(
n
,
ax_order
,
new_shape
);
};
};
...
@@ -422,16 +426,18 @@ CoordinateDiff op::ConvolutionBackpropData::compute_backward_delta_out_pad_above
...
@@ -422,16 +426,18 @@ CoordinateDiff op::ConvolutionBackpropData::compute_backward_delta_out_pad_above
return
backward_delta_out_pad_above
;
return
backward_delta_out_pad_above
;
}
}
const
string
op
::
ConvolutionBackpropFilters
::
type_name
{
"ConvolutionBackpropFilters"
};
op
::
ConvolutionBackpropFilters
::
ConvolutionBackpropFilters
(
op
::
ConvolutionBackpropFilters
::
ConvolutionBackpropFilters
(
const
shared_ptr
<
Node
>&
data_batch
,
const
Output
<
Node
>&
data_batch
,
const
Shape
&
filters_shape
,
const
Shape
&
filters_shape
,
const
shared_ptr
<
Node
>&
output_delta
,
const
Output
<
Node
>&
output_delta
,
const
Strides
&
window_movement_strides_forward
,
const
Strides
&
window_movement_strides_forward
,
const
Strides
&
window_dilation_strides_forward
,
const
Strides
&
window_dilation_strides_forward
,
const
CoordinateDiff
&
padding_below_forward
,
const
CoordinateDiff
&
padding_below_forward
,
const
CoordinateDiff
&
padding_above_forward
,
const
CoordinateDiff
&
padding_above_forward
,
const
Strides
&
data_dilation_strides_forward
)
const
Strides
&
data_dilation_strides_forward
)
:
Op
(
"ConvolutionBackpropFilters"
,
check_single_output_args
({
data_batch
,
output_delta
})
)
:
Op
(
{
data_batch
,
output_delta
}
)
,
m_filters_shape
(
filters_shape
)
,
m_filters_shape
(
filters_shape
)
,
m_window_movement_strides_forward
(
window_movement_strides_forward
)
,
m_window_movement_strides_forward
(
window_movement_strides_forward
)
,
m_window_dilation_strides_forward
(
window_dilation_strides_forward
)
,
m_window_dilation_strides_forward
(
window_dilation_strides_forward
)
...
...
src/ngraph/op/convolution.hpp
View file @
0552fd84
This diff is collapsed.
Click to expand it.
src/ngraph/op/cos.cpp
View file @
0552fd84
...
@@ -22,8 +22,10 @@
...
@@ -22,8 +22,10 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Cos
::
Cos
(
const
shared_ptr
<
Node
>&
arg
)
const
string
op
::
Cos
::
type_name
{
"Cos"
};
:
UnaryElementwiseArithmetic
(
"Cos"
,
arg
)
op
::
Cos
::
Cos
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
...
src/ngraph/op/cos.hpp
View file @
0552fd84
...
@@ -26,10 +26,15 @@ namespace ngraph
...
@@ -26,10 +26,15 @@ namespace ngraph
class
Cos
:
public
util
::
UnaryElementwiseArithmetic
class
Cos
:
public
util
::
UnaryElementwiseArithmetic
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a cosine operation.
Cos
()
=
default
;
/// \brief Constructs a cosine operation.
/// \brief Constructs a cosine operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Cos
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Cos
(
const
Output
<
Node
>&
arg
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/cosh.cpp
View file @
0552fd84
...
@@ -21,8 +21,10 @@
...
@@ -21,8 +21,10 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Cosh
::
Cosh
(
const
shared_ptr
<
Node
>&
arg
)
const
string
op
::
Cosh
::
type_name
{
"Cosh"
};
:
UnaryElementwiseArithmetic
(
"Cosh"
,
arg
)
op
::
Cosh
::
Cosh
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
...
src/ngraph/op/cosh.hpp
View file @
0552fd84
...
@@ -26,10 +26,15 @@ namespace ngraph
...
@@ -26,10 +26,15 @@ namespace ngraph
class
Cosh
:
public
util
::
UnaryElementwiseArithmetic
class
Cosh
:
public
util
::
UnaryElementwiseArithmetic
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a hyperbolic cosine operation.
Cosh
()
=
default
;
/// \brief Constructs a hyperbolic cosine operation.
/// \brief Constructs a hyperbolic cosine operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Cosh
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Cosh
(
const
Output
<
Node
>&
arg
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/dequantize.cpp
View file @
0552fd84
...
@@ -20,13 +20,15 @@
...
@@ -20,13 +20,15 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Dequantize
::
Dequantize
(
const
shared_ptr
<
Node
>&
input
,
const
string
op
::
Dequantize
::
type_name
{
"Dequantize"
};
const
shared_ptr
<
Node
>&
scale
,
const
shared_ptr
<
Node
>&
zero_point
,
op
::
Dequantize
::
Dequantize
(
const
Output
<
Node
>&
input
,
const
Output
<
Node
>&
scale
,
const
Output
<
Node
>&
zero_point
,
const
element
::
Type
&
type
,
const
element
::
Type
&
type
,
const
AxisSet
&
axes
)
const
AxisSet
&
axes
)
:
Op
(
"Dequantize"
,
check_single_output_args
({
input
,
scale
,
zero_point
})
)
:
Op
(
{
input
,
scale
,
zero_point
}
)
,
m_type
(
type
)
,
m_type
(
type
)
,
m_axes
(
axes
)
,
m_axes
(
axes
)
{
{
...
...
src/ngraph/op/dequantize.hpp
View file @
0552fd84
...
@@ -30,31 +30,40 @@ namespace ngraph
...
@@ -30,31 +30,40 @@ namespace ngraph
class
Dequantize
:
public
ngraph
::
op
::
Op
class
Dequantize
:
public
ngraph
::
op
::
Op
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a Dequantize operation
Dequantize
()
=
default
;
/// \brief Constructs a Dequantize operation
/// \brief Constructs a Dequantize operation
/// \param input quantized input
/// \param input quantized input
/// \param scale scale used for mapping
/// \param scale scale used for mapping
/// \param zero_point zero point used for mapping
/// \param zero_point zero point used for mapping
/// \param type output element type
/// \param type output element type
/// \param axes axis positions on which `scale` and `zero_point` are specified
/// \param axes axis positions on which `scale` and `zero_point` are specified
Dequantize
(
const
std
::
shared_ptr
<
Node
>&
input
,
Dequantize
(
const
Output
<
Node
>&
input
,
const
std
::
shared_ptr
<
Node
>&
scale
,
const
Output
<
Node
>&
scale
,
const
std
::
shared_ptr
<
Node
>&
zero_point
,
const
Output
<
Node
>&
zero_point
,
const
ngraph
::
element
::
Type
&
type
,
const
element
::
Type
&
type
,
const
ngraph
::
AxisSet
&
axes
);
const
AxisSet
&
axes
);
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
const
ngraph
::
AxisSet
&
get_axes
()
const
{
return
m_axes
;
}
const
AxisSet
&
get_axes
()
const
{
return
m_axes
;
}
void
set_axes
(
const
AxisSet
&
axes
)
{
m_axes
=
axes
;
}
const
element
::
Type
&
get_type
()
const
{
return
m_type
;
}
void
set_type
(
const
element
::
Type
&
type
)
{
m_type
=
type
;
}
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
private
:
private
:
ngraph
::
element
::
Type
m_type
;
element
::
Type
m_type
;
ngraph
::
AxisSet
m_axes
;
AxisSet
m_axes
;
};
};
}
}
}
}
src/ngraph/op/divide.cpp
View file @
0552fd84
...
@@ -21,20 +21,21 @@
...
@@ -21,20 +21,21 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Divide
::
Divide
(
const
shared_ptr
<
Node
>&
arg0
,
const
string
op
::
Divide
::
type_name
{
"Divide"
};
const
shared_ptr
<
Node
>&
arg1
,
op
::
Divide
::
Divide
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseArithmetic
(
"Divide"
,
arg0
,
arg1
,
autob
)
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
,
autob
)
,
m_pythondiv
(
true
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
op
::
Divide
::
Divide
(
const
shared_ptr
<
Node
>&
arg0
,
op
::
Divide
::
Divide
(
const
Output
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
Output
<
Node
>&
arg1
,
bool
pythondiv
,
bool
pythondiv
,
const
AutoBroadcastSpec
&
autob
)
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseArithmetic
(
"Divide"
,
arg0
,
arg1
,
autob
)
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
,
autob
)
,
m_pythondiv
(
pythondiv
)
,
m_pythondiv
(
pythondiv
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -63,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
...
@@ -63,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
adjoints
.
add_delta
(
y
,
-
delta
*
shared_from_this
()
/
y
);
adjoints
.
add_delta
(
y
,
-
delta
*
shared_from_this
()
/
y
);
}
}
shared_ptr
<
Node
>
ngraph
::
operator
/
(
const
shared_ptr
<
Node
>
arg0
,
const
shared_ptr
<
Node
>
arg1
)
shared_ptr
<
Node
>
ngraph
::
operator
/
(
const
Output
<
Node
>
arg0
,
const
Output
<
Node
>
arg1
)
{
{
return
make_shared
<
op
::
Divide
>
(
arg0
,
arg1
);
return
make_shared
<
op
::
Divide
>
(
arg0
,
arg1
);
}
}
src/ngraph/op/divide.hpp
View file @
0552fd84
...
@@ -26,14 +26,19 @@ namespace ngraph
...
@@ -26,14 +26,19 @@ namespace ngraph
class
Divide
:
public
util
::
BinaryElementwiseArithmetic
class
Divide
:
public
util
::
BinaryElementwiseArithmetic
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a division operation.
Divide
()
=
default
;
/// \brief Constructs a division operation.
/// \brief Constructs a division operation.
///
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param pythondiv Use Python style rounding for integral type
/// \param pythondiv Use Python style rounding for integral type
/// \param autob Auto broadcast specification
/// \param autob Auto broadcast specification
Divide
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
Divide
(
const
Output
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
const
Output
<
Node
>&
arg1
,
bool
pythondiv
,
bool
pythondiv
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
...
@@ -42,11 +47,12 @@ namespace ngraph
...
@@ -42,11 +47,12 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
/// \param autob Auto broadcast specification
Divide
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
Divide
(
const
Output
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
bool
is_pythondiv
()
const
{
return
m_pythondiv
;
}
bool
is_pythondiv
()
const
{
return
m_pythondiv
;
}
void
set_is_pythondiv
(
bool
pythondiv
)
{
m_pythondiv
=
pythondiv
;
}
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
@@ -54,10 +60,10 @@ namespace ngraph
...
@@ -54,10 +60,10 @@ namespace ngraph
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
protected
:
protected
:
bool
m_pythondiv
;
bool
m_pythondiv
{
true
}
;
};
};
}
}
std
::
shared_ptr
<
ngraph
::
Node
>
operator
/
(
const
std
::
shared_ptr
<
ngraph
::
Node
>
arg0
,
std
::
shared_ptr
<
ngraph
::
Node
>
operator
/
(
const
Output
<
ngraph
::
Node
>
arg0
,
const
std
::
shared_ptr
<
ngraph
::
Node
>
arg1
);
const
Output
<
ngraph
::
Node
>
arg1
);
}
}
src/ngraph/op/dot.cpp
View file @
0552fd84
...
@@ -29,16 +29,18 @@
...
@@ -29,16 +29,18 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Dot
::
Dot
(
const
shared_ptr
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
)
const
string
op
::
Dot
::
type_name
{
"Dot"
};
op
::
Dot
::
Dot
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
)
:
Dot
(
arg0
,
arg1
,
0
,
false
)
:
Dot
(
arg0
,
arg1
,
0
,
false
)
{
{
}
}
op
::
Dot
::
Dot
(
const
shared_ptr
<
Node
>&
arg0
,
op
::
Dot
::
Dot
(
const
Output
<
Node
>&
arg0
,
const
shared_ptr
<
Node
>&
arg1
,
const
Output
<
Node
>&
arg1
,
size_t
reduction_axes_count
,
size_t
reduction_axes_count
,
bool
has_reduction_axes_count
)
bool
has_reduction_axes_count
)
:
Op
(
"Dot"
,
check_single_output_args
({
arg0
,
arg1
})
)
:
Op
(
{
arg0
,
arg1
}
)
,
m_reduction_axes_count
(
reduction_axes_count
)
,
m_reduction_axes_count
(
reduction_axes_count
)
,
m_has_reduction_axes_count
(
has_reduction_axes_count
)
,
m_has_reduction_axes_count
(
has_reduction_axes_count
)
{
{
...
@@ -154,7 +156,7 @@ void op::Dot::validate_and_infer_types()
...
@@ -154,7 +156,7 @@ void op::Dot::validate_and_infer_types()
set_output_type
(
0
,
result_et
,
result_shape
);
set_output_type
(
0
,
result_et
,
result_shape
);
}
}
shared_ptr
<
op
::
Reshape
>
make_reshape_axes_to_front
(
const
shared_ptr
<
Node
>&
n
,
shared_ptr
<
op
::
Reshape
>
make_reshape_axes_to_front
(
const
Output
<
Node
>&
n
,
const
Shape
&
front_shape
,
const
Shape
&
front_shape
,
const
Shape
&
back_shape
)
const
Shape
&
back_shape
)
{
{
...
...
src/ngraph/op/dot.hpp
View file @
0552fd84
...
@@ -28,13 +28,18 @@ namespace ngraph
...
@@ -28,13 +28,18 @@ namespace ngraph
class
Dot
:
public
Op
class
Dot
:
public
Op
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a dot product operation.
Dot
()
=
default
;
/// \brief Constructs a dot product operation.
/// \brief Constructs a dot product operation.
///
///
/// \param arg0 The node producing the first argument.
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
/// \param arg1 The node producing the second argument.
/// \param reduction_axes_count The number of axes to dot.
/// \param reduction_axes_count The number of axes to dot.
Dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
Dot
(
const
Output
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
const
Output
<
Node
>&
arg1
,
size_t
reduction_axes_count
,
size_t
reduction_axes_count
,
bool
has_reduction_axes_count
=
true
);
bool
has_reduction_axes_count
=
true
);
...
@@ -48,11 +53,20 @@ namespace ngraph
...
@@ -48,11 +53,20 @@ namespace ngraph
///
///
/// \param arg0 The node producing the first argument.
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
/// \param arg1 The node producing the second argument.
Dot
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
);
Dot
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
);
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
size_t
get_reduction_axes_count
()
const
{
return
m_reduction_axes_count
;
}
size_t
get_reduction_axes_count
()
const
{
return
m_reduction_axes_count
;
}
void
get_reduction_axes_count
(
size_t
reduction_axes_count
)
{
m_reduction_axes_count
=
reduction_axes_count
;
}
bool
get_has_reduction_axes_count
()
const
{
return
m_has_reduction_axes_count
;
}
void
set_has_reduction_axes_count
(
bool
has_reduction_axes_count
)
{
m_has_reduction_axes_count
=
has_reduction_axes_count
;
}
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
{
{
...
...
src/ngraph/op/embedding_lookup.cpp
View file @
0552fd84
...
@@ -19,6 +19,8 @@
...
@@ -19,6 +19,8 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
const
string
op
::
EmbeddingLookup
::
type_name
{
"EmbeddingLookup"
};
void
op
::
EmbeddingLookup
::
validate_and_infer_types
()
void
op
::
EmbeddingLookup
::
validate_and_infer_types
()
{
{
element
::
Type
result_et
=
get_input_element_type
(
1
);
element
::
Type
result_et
=
get_input_element_type
(
1
);
...
...
src/ngraph/op/embedding_lookup.hpp
View file @
0552fd84
...
@@ -28,6 +28,11 @@ namespace ngraph
...
@@ -28,6 +28,11 @@ namespace ngraph
class
EmbeddingLookup
:
public
Op
class
EmbeddingLookup
:
public
Op
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a EmbeddingLookup operation.
EmbeddingLookup
()
=
default
;
/// \brief Constructs a EmbeddingLookup operation.
/// \brief Constructs a EmbeddingLookup operation.
///
///
/// EmbeddingLookup constructs an output tensor by replacing every index in a given input tensor
/// EmbeddingLookup constructs an output tensor by replacing every index in a given input tensor
...
@@ -36,8 +41,8 @@ namespace ngraph
...
@@ -36,8 +41,8 @@ namespace ngraph
/// \param data The input indices for tokens to be translated into embeddings
/// \param data The input indices for tokens to be translated into embeddings
/// \param weights is a dense matrix [N,M] where each row 0..N
/// \param weights is a dense matrix [N,M] where each row 0..N
/// corresponds to an embedding (i.e. typically, a vector of real numbers) of length M
/// corresponds to an embedding (i.e. typically, a vector of real numbers) of length M
EmbeddingLookup
(
const
std
::
shared_ptr
<
Node
>&
data
,
const
std
::
shared_ptr
<
Node
>&
weights
)
EmbeddingLookup
(
const
Output
<
Node
>&
data
,
const
Output
<
Node
>&
weights
)
:
Op
(
"EmbeddingLookup"
,
check_single_output_args
({
data
,
weights
})
)
:
Op
(
{
data
,
weights
}
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
...
src/ngraph/op/equal.cpp
View file @
0552fd84
...
@@ -19,10 +19,10 @@
...
@@ -19,10 +19,10 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Equal
::
Equal
(
const
shared_ptr
<
Node
>&
arg0
,
const
string
op
::
Equal
::
type_name
{
"Equal"
};
const
shared_ptr
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
op
::
Equal
::
Equal
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseComparison
(
"Equal"
,
arg0
,
arg1
,
autob
)
:
BinaryElementwiseComparison
(
arg0
,
arg1
,
autob
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
...
src/ngraph/op/equal.hpp
View file @
0552fd84
...
@@ -40,13 +40,18 @@ namespace ngraph
...
@@ -40,13 +40,18 @@ namespace ngraph
class
Equal
:
public
util
::
BinaryElementwiseComparison
class
Equal
:
public
util
::
BinaryElementwiseComparison
{
{
public
:
public
:
/// \brief Constructs an is-equal operation.
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs an equal operation.
Equal
()
=
default
;
/// \brief Constructs an equal operation.
///
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
/// \param autob Auto broadcast specification
Equal
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
Equal
(
const
Output
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/erf.cpp
View file @
0552fd84
...
@@ -21,14 +21,16 @@
...
@@ -21,14 +21,16 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
const
string
op
::
Erf
::
type_name
{
"Erf"
};
shared_ptr
<
Node
>
op
::
Erf
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Erf
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
Erf
>
(
new_args
.
at
(
0
));
return
make_shared
<
Erf
>
(
new_args
.
at
(
0
));
}
}
op
::
Erf
::
Erf
(
shared_ptr
<
Node
>
arg
)
op
::
Erf
::
Erf
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
"Erf"
,
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
src/ngraph/op/erf.hpp
View file @
0552fd84
...
@@ -27,7 +27,11 @@ namespace ngraph
...
@@ -27,7 +27,11 @@ namespace ngraph
class
Erf
:
public
util
::
UnaryElementwiseArithmetic
class
Erf
:
public
util
::
UnaryElementwiseArithmetic
{
{
public
:
public
:
Erf
(
std
::
shared_ptr
<
Node
>
arg
);
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
Erf
()
=
default
;
Erf
(
const
Output
<
Node
>&
arg
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/exp.cpp
View file @
0552fd84
...
@@ -20,8 +20,10 @@
...
@@ -20,8 +20,10 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Exp
::
Exp
(
const
shared_ptr
<
Node
>&
arg
)
const
string
op
::
Exp
::
type_name
{
"Exp"
};
:
UnaryElementwiseArithmetic
(
"Exp"
,
arg
)
op
::
Exp
::
Exp
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
...
src/ngraph/op/exp.hpp
View file @
0552fd84
...
@@ -26,10 +26,15 @@ namespace ngraph
...
@@ -26,10 +26,15 @@ namespace ngraph
class
Exp
:
public
util
::
UnaryElementwiseArithmetic
class
Exp
:
public
util
::
UnaryElementwiseArithmetic
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs an exponential operation.
Exp
()
=
default
;
/// \brief Constructs an exponential operation.
/// \brief Constructs an exponential operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Exp
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Exp
(
const
Output
<
Node
>&
arg
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/floor.cpp
View file @
0552fd84
...
@@ -19,8 +19,10 @@
...
@@ -19,8 +19,10 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Floor
::
Floor
(
const
shared_ptr
<
Node
>&
arg
)
const
string
op
::
Floor
::
type_name
{
"Floor"
};
:
UnaryElementwiseArithmetic
(
"Floor"
,
arg
)
op
::
Floor
::
Floor
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
...
src/ngraph/op/floor.hpp
View file @
0552fd84
...
@@ -26,10 +26,15 @@ namespace ngraph
...
@@ -26,10 +26,15 @@ namespace ngraph
class
Floor
:
public
util
::
UnaryElementwiseArithmetic
class
Floor
:
public
util
::
UnaryElementwiseArithmetic
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a floor operation.
Floor
()
=
default
;
/// \brief Constructs a floor operation.
/// \brief Constructs a floor operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Floor
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Floor
(
const
Output
<
Node
>&
arg
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/reshape.cpp
View file @
0552fd84
...
@@ -24,10 +24,12 @@
...
@@ -24,10 +24,12 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Reshape
::
Reshape
(
const
shared_ptr
<
Node
>&
arg
,
const
string
op
::
Reshape
::
type_name
{
"Reshape"
};
op
::
Reshape
::
Reshape
(
const
Output
<
Node
>&
arg
,
const
AxisVector
&
input_order
,
const
AxisVector
&
input_order
,
const
Shape
&
output_shape
)
const
Shape
&
output_shape
)
:
Op
(
"Reshape"
,
check_single_output_args
({
arg
})
)
:
Op
(
{
arg
}
)
,
m_input_order
(
input_order
)
,
m_input_order
(
input_order
)
,
m_output_shape
(
output_shape
)
,
m_output_shape
(
output_shape
)
{
{
...
...
src/ngraph/op/reshape.hpp
View file @
0552fd84
...
@@ -60,6 +60,11 @@ namespace ngraph
...
@@ -60,6 +60,11 @@ namespace ngraph
class
Reshape
:
public
Op
class
Reshape
:
public
Op
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a reshape operation.
Reshape
()
=
default
;
/// \brief Constructs a reshape operation.
/// \brief Constructs a reshape operation.
///
///
/// \param arg The tensor to be reshaped.
/// \param arg The tensor to be reshaped.
...
@@ -67,7 +72,7 @@ namespace ngraph
...
@@ -67,7 +72,7 @@ namespace ngraph
/// sequence \f$(0,\dots,n-1)\f$ where \f$n\f$ is the rank of the input tensor.
/// sequence \f$(0,\dots,n-1)\f$ where \f$n\f$ is the rank of the input tensor.
/// \param output_shape The output shape. If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape must
/// \param output_shape The output shape. If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape must
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
Reshape
(
const
std
::
shared_ptr
<
Node
>&
arg
,
Reshape
(
const
Output
<
Node
>&
arg
,
const
AxisVector
&
input_order
,
const
AxisVector
&
input_order
,
const
Shape
&
output_shape
);
const
Shape
&
output_shape
);
...
@@ -78,15 +83,18 @@ namespace ngraph
...
@@ -78,15 +83,18 @@ namespace ngraph
/// \return The order in which to iterate over input axes.
/// \return The order in which to iterate over input axes.
const
AxisVector
&
get_input_order
()
const
{
return
m_input_order
;
}
const
AxisVector
&
get_input_order
()
const
{
return
m_input_order
;
}
void
set_input_order
(
const
AxisVector
&
input_order
)
{
m_input_order
=
input_order
;
}
/// \return The shape of the output tensor.
/// \return The shape of the output tensor.
const
Shape
&
get_output_shape
()
const
{
return
m_output_shape
;
}
const
Shape
&
get_output_shape
()
const
{
return
m_output_shape
;
}
void
set_output_shape
(
const
Shape
&
output_shape
)
{
m_output_shape
=
output_shape
;
}
bool
get_is_transpose
()
const
{
return
m_is_transpose
;
}
bool
get_is_transpose
()
const
{
return
m_is_transpose
;
}
void
set_is_transpose
(
bool
is_transpose
)
{
m_is_transpose
=
is_transpose
;
}
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
const
AxisVector
m_input_order
;
AxisVector
m_input_order
;
const
Shape
m_output_shape
;
Shape
m_output_shape
;
bool
m_is_transpose
{
false
};
bool
m_is_transpose
{
false
};
};
};
}
}
...
...
src/ngraph/op/result.cpp
View file @
0552fd84
...
@@ -24,8 +24,10 @@
...
@@ -24,8 +24,10 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Result
::
Result
(
const
shared_ptr
<
Node
>&
arg
,
bool
needs_default_layout
)
const
string
op
::
Result
::
type_name
{
"Result"
};
:
Op
(
"Result"
,
check_single_output_args
({
arg
}))
op
::
Result
::
Result
(
const
Output
<
Node
>&
arg
,
bool
needs_default_layout
)
:
Op
({
arg
})
,
m_needs_default_layout
(
needs_default_layout
)
,
m_needs_default_layout
(
needs_default_layout
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
...
src/ngraph/op/result.hpp
View file @
0552fd84
...
@@ -27,10 +27,15 @@ namespace ngraph
...
@@ -27,10 +27,15 @@ namespace ngraph
class
Result
:
public
Op
class
Result
:
public
Op
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Allows a value to be used as a function result.
Result
()
=
default
;
/// \brief Allows a value to be used as a function result.
/// \brief Allows a value to be used as a function result.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Result
(
const
std
::
shared_ptr
<
Node
>&
arg
,
bool
needs_default_layout
=
false
);
Result
(
const
Output
<
Node
>&
arg
,
bool
needs_default_layout
=
false
);
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment