Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
6b5056e8
Unverified
Commit
6b5056e8
authored
Aug 05, 2019
by
Scott Cyphers
Committed by
GitHub
Aug 05, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Convert remaining op/ ops to Output<Node> constructors (#3373)
parent
a92b30be
Hide whitespace changes
Inline
Side-by-side
Showing
47 changed files
with
143 additions
and
90 deletions
+143
-90
pad.cpp
src/ngraph/op/pad.cpp
+3
-3
pad.hpp
src/ngraph/op/pad.hpp
+14
-2
parameter.hpp
src/ngraph/op/parameter.hpp
+1
-0
passthrough.cpp
src/ngraph/op/passthrough.cpp
+15
-0
passthrough.hpp
src/ngraph/op/passthrough.hpp
+7
-0
power.cpp
src/ngraph/op/power.cpp
+1
-3
power.hpp
src/ngraph/op/power.hpp
+3
-2
product.cpp
src/ngraph/op/product.cpp
+0
-4
product.hpp
src/ngraph/op/product.hpp
+1
-1
quantize.hpp
src/ngraph/op/quantize.hpp
+2
-0
quantized_convolution.hpp
src/ngraph/op/quantized_convolution.hpp
+2
-0
relu.cpp
src/ngraph/op/relu.cpp
+1
-1
relu.hpp
src/ngraph/op/relu.hpp
+2
-1
replace_slice.hpp
src/ngraph/op/replace_slice.hpp
+7
-0
reverse.cpp
src/ngraph/op/reverse.cpp
+3
-3
reverse.hpp
src/ngraph/op/reverse.hpp
+8
-2
reverse_sequence.cpp
src/ngraph/op/reverse_sequence.cpp
+6
-6
reverse_sequence.hpp
src/ngraph/op/reverse_sequence.hpp
+5
-2
scatter_add.hpp
src/ngraph/op/scatter_add.hpp
+5
-4
scatter_nd_add.hpp
src/ngraph/op/scatter_nd_add.hpp
+5
-4
select.cpp
src/ngraph/op/select.cpp
+5
-5
sigmoid.cpp
src/ngraph/op/sigmoid.cpp
+4
-4
sigmoid.hpp
src/ngraph/op/sigmoid.hpp
+4
-2
sign.cpp
src/ngraph/op/sign.cpp
+1
-1
sign.hpp
src/ngraph/op/sign.hpp
+2
-1
sin.cpp
src/ngraph/op/sin.cpp
+2
-2
sin.hpp
src/ngraph/op/sin.hpp
+2
-1
sinh.cpp
src/ngraph/op/sinh.cpp
+1
-1
sinh.hpp
src/ngraph/op/sinh.hpp
+2
-1
slice.cpp
src/ngraph/op/slice.cpp
+1
-5
slice.hpp
src/ngraph/op/slice.hpp
+1
-1
softmax.cpp
src/ngraph/op/softmax.cpp
+2
-2
softmax.hpp
src/ngraph/op/softmax.hpp
+3
-1
sqrt.cpp
src/ngraph/op/sqrt.cpp
+2
-2
sqrt.hpp
src/ngraph/op/sqrt.hpp
+2
-1
stop_gradient.cpp
src/ngraph/op/stop_gradient.cpp
+1
-1
stop_gradient.hpp
src/ngraph/op/stop_gradient.hpp
+2
-1
subtract.cpp
src/ngraph/op/subtract.cpp
+2
-2
sum.cpp
src/ngraph/op/sum.cpp
+0
-4
sum.hpp
src/ngraph/op/sum.hpp
+1
-1
tan.cpp
src/ngraph/op/tan.cpp
+2
-2
tan.hpp
src/ngraph/op/tan.hpp
+2
-1
tanh.cpp
src/ngraph/op/tanh.cpp
+2
-2
tanh.hpp
src/ngraph/op/tanh.hpp
+2
-1
topk.cpp
src/ngraph/op/topk.cpp
+2
-5
topk.hpp
src/ngraph/op/topk.hpp
+1
-1
serializer.cpp
src/ngraph/serializer.cpp
+1
-1
No files found.
src/ngraph/op/pad.cpp
View file @
6b5056e8
...
@@ -23,12 +23,12 @@ using namespace ngraph;
...
@@ -23,12 +23,12 @@ using namespace ngraph;
const
string
op
::
Pad
::
type_name
{
"Pad"
};
const
string
op
::
Pad
::
type_name
{
"Pad"
};
op
::
Pad
::
Pad
(
const
shared_ptr
<
Node
>&
arg
,
op
::
Pad
::
Pad
(
const
Output
<
Node
>&
arg
,
const
shared_ptr
<
Node
>&
arg_pad_value
,
const
Output
<
Node
>&
arg_pad_value
,
const
CoordinateDiff
&
padding_below
,
const
CoordinateDiff
&
padding_below
,
const
CoordinateDiff
&
padding_above
,
const
CoordinateDiff
&
padding_above
,
PadMode
pad_mode
)
PadMode
pad_mode
)
:
Op
(
check_single_output_args
({
arg
,
arg_pad_value
})
)
:
Op
(
{
arg
,
arg_pad_value
}
)
,
m_padding_below
(
padding_below
)
,
m_padding_below
(
padding_below
)
,
m_padding_above
(
padding_above
)
,
m_padding_above
(
padding_above
)
,
m_padding_interior_fake
(
padding_below
.
size
())
,
m_padding_interior_fake
(
padding_below
.
size
())
...
...
src/ngraph/op/pad.hpp
View file @
6b5056e8
...
@@ -32,14 +32,16 @@ namespace ngraph
...
@@ -32,14 +32,16 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a generic padding operation.
/// \brief Constructs a generic padding operation.
Pad
()
=
default
;
/// \brief Constructs a generic padding operation.
///
///
/// \param arg The node producing input tensor to be padded.
/// \param arg The node producing input tensor to be padded.
/// \param arg_pad_value The node producing the scalar value to be inserted for padding.
/// \param arg_pad_value The node producing the scalar value to be inserted for padding.
/// \param padding_below The padding-below widths.
/// \param padding_below The padding-below widths.
/// \param padding_above The padding-above widths.
/// \param padding_above The padding-above widths.
/// \param pad_mode The padding mode: CONSTANT(default), EDGE, REFLECT or SYMMETRIC.
/// \param pad_mode The padding mode: CONSTANT(default), EDGE, REFLECT or SYMMETRIC.
Pad
(
const
std
::
shared_ptr
<
Node
>&
arg
,
Pad
(
const
Output
<
Node
>&
arg
,
const
std
::
shared_ptr
<
Node
>&
arg_pad_value
,
const
Output
<
Node
>&
arg_pad_value
,
const
CoordinateDiff
&
padding_below
,
const
CoordinateDiff
&
padding_below
,
const
CoordinateDiff
&
padding_above
,
const
CoordinateDiff
&
padding_above
,
PadMode
pad_mode
=
PadMode
::
CONSTANT
);
PadMode
pad_mode
=
PadMode
::
CONSTANT
);
...
@@ -49,14 +51,24 @@ namespace ngraph
...
@@ -49,14 +51,24 @@ namespace ngraph
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
/// \return The padding-below sizes.
/// \return The padding-below sizes.
const
CoordinateDiff
&
get_padding_below
()
const
{
return
m_padding_below
;
}
const
CoordinateDiff
&
get_padding_below
()
const
{
return
m_padding_below
;
}
void
set_padding_below
(
const
CoordinateDiff
&
padding_below
)
{
m_padding_below
=
padding_below
;
}
/// \return The padding-above sizes.
/// \return The padding-above sizes.
const
CoordinateDiff
&
get_padding_above
()
const
{
return
m_padding_above
;
}
const
CoordinateDiff
&
get_padding_above
()
const
{
return
m_padding_above
;
}
void
set_padding_above
(
const
CoordinateDiff
&
padding_above
)
{
m_padding_below
=
padding_above
;
}
/// \brief DEPRECATED. This is just a stub for backends that used to implement the
/// \brief DEPRECATED. This is just a stub for backends that used to implement the
/// interior padding feature, which is no longer supported.
/// interior padding feature, which is no longer supported.
/// \return Returns a shape full of zeros, with the same rank as get_padding_below().
/// \return Returns a shape full of zeros, with the same rank as get_padding_below().
const
Shape
&
get_padding_interior
()
const
{
return
m_padding_interior_fake
;
}
const
Shape
&
get_padding_interior
()
const
{
return
m_padding_interior_fake
;
}
/// \return The padding mode.
/// \return The padding mode.
PadMode
get_pad_mode
()
const
{
return
m_pad_mode
;
}
PadMode
get_pad_mode
()
const
{
return
m_pad_mode
;
}
void
set_pad_mode
(
PadMode
pad_mode
)
{
m_pad_mode
=
pad_mode
;
}
/// \return The default value for Pad.
/// \return The default value for Pad.
virtual
std
::
shared_ptr
<
Node
>
get_default_value
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
get_default_value
()
const
override
;
...
...
src/ngraph/op/parameter.hpp
View file @
6b5056e8
...
@@ -38,6 +38,7 @@ namespace ngraph
...
@@ -38,6 +38,7 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
Parameter
()
=
default
;
/// \brief Constructions a tensor-typed parameter node.
/// \brief Constructions a tensor-typed parameter node.
///
///
/// \param element_type The element type of the parameter.
/// \param element_type The element type of the parameter.
...
...
src/ngraph/op/passthrough.cpp
View file @
6b5056e8
...
@@ -38,6 +38,21 @@ ngraph::op::Passthrough::Passthrough(const std::string& logical_type,
...
@@ -38,6 +38,21 @@ ngraph::op::Passthrough::Passthrough(const std::string& logical_type,
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
ngraph
::
op
::
Passthrough
::
Passthrough
(
const
std
::
string
&
logical_type
,
const
std
::
string
&
language
,
const
std
::
string
&
function
,
const
OutputVector
&
args
,
std
::
vector
<
std
::
tuple
<
element
::
Type
,
PartialShape
>>
outputs
)
:
Op
{
args
}
,
m_logical_type
{
logical_type
}
,
m_language
{
language
}
,
m_function
{
function
}
,
m_output_shapes
{
std
::
move
(
outputs
)}
{
set_output_size
(
m_output_shapes
.
size
());
constructor_validate_and_infer_types
();
}
void
ngraph
::
op
::
Passthrough
::
validate_and_infer_types
()
void
ngraph
::
op
::
Passthrough
::
validate_and_infer_types
()
{
{
// N.B. It would be useful to have the backend deduce the output
// N.B. It would be useful to have the backend deduce the output
...
...
src/ngraph/op/passthrough.hpp
View file @
6b5056e8
...
@@ -41,12 +41,19 @@ public:
...
@@ -41,12 +41,19 @@ public:
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
Passthrough
()
=
default
;
Passthrough
(
const
std
::
string
&
logical_type
,
// aka "What this operation is doing"
Passthrough
(
const
std
::
string
&
logical_type
,
// aka "What this operation is doing"
const
std
::
string
&
language
,
// The language the implementation is written in
const
std
::
string
&
language
,
// The language the implementation is written in
const
std
::
string
&
function
,
// The operation implementation
const
std
::
string
&
function
,
// The operation implementation
const
NodeVector
&
args
,
const
NodeVector
&
args
,
std
::
vector
<
std
::
tuple
<
element
::
Type
,
PartialShape
>>
outputs
);
std
::
vector
<
std
::
tuple
<
element
::
Type
,
PartialShape
>>
outputs
);
Passthrough
(
const
std
::
string
&
logical_type
,
// aka "What this operation is doing"
const
std
::
string
&
language
,
// The language the implementation is written in
const
std
::
string
&
function
,
// The operation implementation
const
OutputVector
&
args
,
std
::
vector
<
std
::
tuple
<
element
::
Type
,
PartialShape
>>
outputs
);
void
validate_and_infer_types
()
final
;
void
validate_and_infer_types
()
final
;
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
final
;
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
final
;
...
...
src/ngraph/op/power.cpp
View file @
6b5056e8
...
@@ -24,9 +24,7 @@ using namespace ngraph;
...
@@ -24,9 +24,7 @@ using namespace ngraph;
const
string
op
::
Power
::
type_name
{
"Power"
};
const
string
op
::
Power
::
type_name
{
"Power"
};
op
::
Power
::
Power
(
const
shared_ptr
<
Node
>&
arg0
,
op
::
Power
::
Power
(
const
Output
<
Node
>&
arg0
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
const
shared_ptr
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
,
autob
)
:
BinaryElementwiseArithmetic
(
arg0
,
arg1
,
autob
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
...
src/ngraph/op/power.hpp
View file @
6b5056e8
...
@@ -42,13 +42,14 @@ namespace ngraph
...
@@ -42,13 +42,14 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
Power
()
=
default
;
/// \brief Constructs an exponentiation operation.
/// \brief Constructs an exponentiation operation.
///
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
/// \param autob Auto broadcast specification
Power
(
const
std
::
shared_ptr
<
Node
>&
arg0
,
Power
(
const
Output
<
Node
>&
arg0
,
const
std
::
shared_ptr
<
Node
>&
arg1
,
const
Output
<
Node
>&
arg1
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/product.cpp
View file @
6b5056e8
...
@@ -21,10 +21,6 @@ using namespace ngraph;
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const
string
op
::
Product
::
type_name
{
"Product"
};
const
string
op
::
Product
::
type_name
{
"Product"
};
op
::
Product
::
Product
()
{
}
op
::
Product
::
Product
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reduction_axes
)
op
::
Product
::
Product
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reduction_axes
)
:
ArithmeticReduction
(
arg
,
reduction_axes
)
:
ArithmeticReduction
(
arg
,
reduction_axes
)
{
{
...
...
src/ngraph/op/product.hpp
View file @
6b5056e8
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a product reduction operation.
/// \brief Constructs a product reduction operation.
Product
();
Product
()
=
default
;
/// \brief Constructs a product reduction operation.
/// \brief Constructs a product reduction operation.
///
///
/// \param arg The tensor to be reduced.
/// \param arg The tensor to be reduced.
...
...
src/ngraph/op/quantize.hpp
View file @
6b5056e8
...
@@ -92,6 +92,8 @@ namespace ngraph
...
@@ -92,6 +92,8 @@ namespace ngraph
const
ngraph
::
AxisSet
&
axes
,
const
ngraph
::
AxisSet
&
axes
,
RoundMode
round_mode
);
RoundMode
round_mode
);
Quantize
()
=
default
;
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/quantized_convolution.hpp
View file @
6b5056e8
...
@@ -66,6 +66,8 @@ namespace ngraph
...
@@ -66,6 +66,8 @@ namespace ngraph
const
ngraph
::
AxisSet
&
filter_axes
=
ngraph
::
AxisSet
{},
const
ngraph
::
AxisSet
&
filter_axes
=
ngraph
::
AxisSet
{},
const
ngraph
::
AxisSet
&
output_axes
=
ngraph
::
AxisSet
{});
const
ngraph
::
AxisSet
&
output_axes
=
ngraph
::
AxisSet
{});
QuantizedConvolution
()
=
default
;
const
Strides
&
get_window_movement_strides
()
const
{
return
m_window_movement_strides
;
}
const
Strides
&
get_window_movement_strides
()
const
{
return
m_window_movement_strides
;
}
const
Strides
&
get_window_dilation_strides
()
const
{
return
m_window_dilation_strides
;
}
const
Strides
&
get_window_dilation_strides
()
const
{
return
m_window_dilation_strides
;
}
const
CoordinateDiff
&
get_padding_below
()
const
{
return
m_padding_below
;
}
const
CoordinateDiff
&
get_padding_below
()
const
{
return
m_padding_below
;
}
...
...
src/ngraph/op/relu.cpp
View file @
6b5056e8
...
@@ -23,7 +23,7 @@ using namespace ngraph;
...
@@ -23,7 +23,7 @@ using namespace ngraph;
const
string
op
::
Relu
::
type_name
{
"Relu"
};
const
string
op
::
Relu
::
type_name
{
"Relu"
};
const
string
op
::
ReluBackprop
::
type_name
{
"ReluBackprop"
};
const
string
op
::
ReluBackprop
::
type_name
{
"ReluBackprop"
};
op
::
Relu
::
Relu
(
shared_ptr
<
Node
>
arg
)
op
::
Relu
::
Relu
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
...
src/ngraph/op/relu.hpp
View file @
6b5056e8
...
@@ -36,10 +36,11 @@ namespace ngraph
...
@@ -36,10 +36,11 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
Relu
()
=
default
;
/// \brief Constructs a Relu operation.
/// \brief Constructs a Relu operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Relu
(
std
::
shared_ptr
<
ngraph
::
Node
>
arg
);
Relu
(
const
Output
<
ngraph
::
Node
>&
arg
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/replace_slice.hpp
View file @
6b5056e8
...
@@ -53,6 +53,7 @@ namespace ngraph
...
@@ -53,6 +53,7 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
ReplaceSlice
()
=
default
;
/// \brief Constructs a tensor slice replacement operation.
/// \brief Constructs a tensor slice replacement operation.
///
///
/// \param arg0 The tensor to overwrite into.
/// \param arg0 The tensor to overwrite into.
...
@@ -85,10 +86,16 @@ namespace ngraph
...
@@ -85,10 +86,16 @@ namespace ngraph
/// \return The inclusive lower-bound coordinates.
/// \return The inclusive lower-bound coordinates.
const
Coordinate
&
get_lower_bounds
()
const
{
return
m_lower_bounds
;
}
const
Coordinate
&
get_lower_bounds
()
const
{
return
m_lower_bounds
;
}
void
set_lower_bounds
(
const
Coordinate
&
lower_bounds
)
{
m_lower_bounds
=
lower_bounds
;
}
/// \return The exclusive upper-bound coordinates.
/// \return The exclusive upper-bound coordinates.
const
Coordinate
&
get_upper_bounds
()
const
{
return
m_upper_bounds
;
}
const
Coordinate
&
get_upper_bounds
()
const
{
return
m_upper_bounds
;
}
void
set_uppper_bounds
(
const
Coordinate
&
upper_bounds
)
{
m_upper_bounds
=
upper_bounds
;
}
/// \return The slicing strides.
/// \return The slicing strides.
const
Strides
&
get_strides
()
const
{
return
m_strides
;
}
const
Strides
&
get_strides
()
const
{
return
m_strides
;
}
void
set_strides
(
const
Strides
&
strides
)
{
m_strides
=
strides
;
}
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
...
...
src/ngraph/op/reverse.cpp
View file @
6b5056e8
...
@@ -25,8 +25,8 @@ using namespace ngraph;
...
@@ -25,8 +25,8 @@ using namespace ngraph;
const
string
op
::
Reverse
::
type_name
{
"Reverse"
};
const
string
op
::
Reverse
::
type_name
{
"Reverse"
};
op
::
Reverse
::
Reverse
(
const
shared_ptr
<
Node
>&
arg
,
const
AxisSet
&
reversed_axes
)
op
::
Reverse
::
Reverse
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reversed_axes
)
:
Op
(
check_single_output_args
({
arg
})
)
:
Op
(
{
arg
}
)
,
m_reversed_axes
(
reversed_axes
)
,
m_reversed_axes
(
reversed_axes
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -65,7 +65,7 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
...
@@ -65,7 +65,7 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
{
{
auto
delta
=
deltas
.
at
(
0
);
auto
delta
=
deltas
.
at
(
0
);
auto
x
=
get_argument
(
0
);
auto
x
=
input
(
0
).
get_source_output
(
);
adjoints
.
add_delta
(
x
,
make_shared
<
op
::
Reverse
>
(
delta
,
m_reversed_axes
));
adjoints
.
add_delta
(
x
,
make_shared
<
op
::
Reverse
>
(
delta
,
m_reversed_axes
));
}
}
src/ngraph/op/reverse.hpp
View file @
6b5056e8
...
@@ -49,11 +49,12 @@ namespace ngraph
...
@@ -49,11 +49,12 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
Reverse
()
=
default
;
/// \brief Constructs a reverse operation.
/// \brief Constructs a reverse operation.
///
///
/// \param arg The input tensor, some of whose axes are to be reversed.
/// \param arg The input tensor, some of whose axes are to be reversed.
/// \param reversed_axes The axes to reverse.
/// \param reversed_axes The axes to reverse.
Reverse
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
AxisSet
&
reversed_axes
);
Reverse
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reversed_axes
);
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
...
@@ -62,11 +63,16 @@ namespace ngraph
...
@@ -62,11 +63,16 @@ namespace ngraph
/// \return The set of axes to reverse.
/// \return The set of axes to reverse.
const
AxisSet
&
get_reversed_axes
()
const
{
return
m_reversed_axes
;
}
const
AxisSet
&
get_reversed_axes
()
const
{
return
m_reversed_axes
;
}
void
set_reversed_axes
(
const
AxisSet
&
reversed_axes
)
{
m_reversed_axes
=
reversed_axes
;
}
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
const
AxisSet
m_reversed_axes
;
AxisSet
m_reversed_axes
;
};
};
}
}
}
}
src/ngraph/op/reverse_sequence.cpp
View file @
6b5056e8
...
@@ -27,11 +27,11 @@ using namespace ngraph;
...
@@ -27,11 +27,11 @@ using namespace ngraph;
const
string
op
::
ReverseSequence
::
type_name
{
"ReverseSequence"
};
const
string
op
::
ReverseSequence
::
type_name
{
"ReverseSequence"
};
op
::
ReverseSequence
::
ReverseSequence
(
const
std
::
shared_ptr
<
Node
>
arg
,
op
::
ReverseSequence
::
ReverseSequence
(
const
Output
<
Node
>&
arg
,
const
std
::
shared_ptr
<
Node
>
seq_indices
,
const
Output
<
Node
>&
seq_indices
,
size_t
batch_axis
,
size_t
batch_axis
,
size_t
seq_axis
)
size_t
seq_axis
)
:
Op
(
check_single_output_args
({
arg
,
seq_indices
})
)
:
Op
(
{
arg
,
seq_indices
}
)
,
m_batch_axis
(
batch_axis
)
,
m_batch_axis
(
batch_axis
)
,
m_seq_axis
(
seq_axis
)
,
m_seq_axis
(
seq_axis
)
{
{
...
@@ -104,8 +104,8 @@ shared_ptr<Node> op::ReverseSequence::copy_with_new_args(const NodeVector& new_a
...
@@ -104,8 +104,8 @@ shared_ptr<Node> op::ReverseSequence::copy_with_new_args(const NodeVector& new_a
void
op
::
ReverseSequence
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
void
op
::
ReverseSequence
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
{
{
auto
x
=
get_argument
(
0
);
auto
x
=
input
(
0
).
get_source_output
(
);
auto
rs_delta
=
auto
rs_delta
=
make_shared
<
ReverseSequence
>
(
make_shared
<
ReverseSequence
>
(
deltas
.
at
(
0
),
get_argument
(
1
),
m_batch_axis
,
m_seq_axis
);
deltas
.
at
(
0
),
input
(
1
).
get_source_output
(
),
m_batch_axis
,
m_seq_axis
);
adjoints
.
add_delta
(
x
,
rs_delta
);
adjoints
.
add_delta
(
x
,
rs_delta
);
}
}
src/ngraph/op/reverse_sequence.hpp
View file @
6b5056e8
...
@@ -28,11 +28,12 @@ namespace ngraph
...
@@ -28,11 +28,12 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
ReverseSequence
()
=
default
;
/// \brief Constructs an arcsin operation.
/// \brief Constructs an arcsin operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
ReverseSequence
(
const
std
::
shared_ptr
<
Node
>
arg
,
ReverseSequence
(
const
Output
<
Node
>&
arg
,
const
std
::
shared_ptr
<
Node
>
seq_lengths
,
const
Output
<
Node
>&
seq_lengths
,
size_t
batch_axis
,
size_t
batch_axis
,
size_t
seq_axis
);
size_t
seq_axis
);
...
@@ -42,7 +43,9 @@ namespace ngraph
...
@@ -42,7 +43,9 @@ namespace ngraph
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
size_t
get_batch_axis
()
const
{
return
m_batch_axis
;
}
size_t
get_batch_axis
()
const
{
return
m_batch_axis
;
}
void
set_batch_axis
(
size_t
batch_axis
)
{
m_batch_axis
=
batch_axis
;
}
size_t
get_sequence_axis
()
const
{
return
m_seq_axis
;
}
size_t
get_sequence_axis
()
const
{
return
m_seq_axis
;
}
void
set_sequence_axis
(
size_t
sequence_axis
)
{
m_seq_axis
=
sequence_axis
;
}
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
...
...
src/ngraph/op/scatter_add.hpp
View file @
6b5056e8
...
@@ -29,13 +29,14 @@ namespace ngraph
...
@@ -29,13 +29,14 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
ScatterAdd
()
=
default
;
/// \param inputs Tensor
/// \param inputs Tensor
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param updates Tensor: Must have same type as inputs
/// \param updates Tensor: Must have same type as inputs
ScatterAdd
(
const
std
::
shared_ptr
<
Node
>&
inputs
,
ScatterAdd
(
const
Output
<
Node
>&
inputs
,
const
std
::
shared_ptr
<
Node
>&
indices
,
const
Output
<
Node
>&
indices
,
const
std
::
shared_ptr
<
Node
>&
updates
)
const
Output
<
Node
>&
updates
)
:
Op
(
check_single_output_args
({
inputs
,
indices
,
updates
})
)
:
Op
(
{
inputs
,
indices
,
updates
}
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
...
src/ngraph/op/scatter_nd_add.hpp
View file @
6b5056e8
...
@@ -29,13 +29,14 @@ namespace ngraph
...
@@ -29,13 +29,14 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
ScatterNDAdd
()
=
default
;
/// \param inputs Tensor
/// \param inputs Tensor
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param updates Tensor: Must have same type as inputs
/// \param updates Tensor: Must have same type as inputs
ScatterNDAdd
(
const
std
::
shared_ptr
<
Node
>&
inputs
,
ScatterNDAdd
(
const
Output
<
Node
>&
inputs
,
const
std
::
shared_ptr
<
Node
>&
indices
,
const
Output
<
Node
>&
indices
,
const
std
::
shared_ptr
<
Node
>&
updates
)
const
Output
<
Node
>&
updates
)
:
Op
(
check_single_output_args
({
inputs
,
indices
,
updates
})
)
:
Op
(
{
inputs
,
indices
,
updates
}
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
...
src/ngraph/op/select.cpp
View file @
6b5056e8
...
@@ -72,12 +72,12 @@ void op::Select::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
...
@@ -72,12 +72,12 @@ void op::Select::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
{
{
auto
delta
=
deltas
.
at
(
0
);
auto
delta
=
deltas
.
at
(
0
);
auto
p
=
get_argument
(
0
);
auto
p
=
input
(
0
).
get_source_output
(
);
auto
x
=
get_argument
(
1
);
auto
x
=
input
(
1
).
get_source_output
(
);
auto
y
=
get_argument
(
2
);
auto
y
=
input
(
2
).
get_source_output
(
);
auto
p_as_x_type
=
make_shared
<
op
::
Convert
>
(
p
,
x
->
get_element_type
());
auto
p_as_x_type
=
make_shared
<
op
::
Convert
>
(
p
,
x
.
get_element_type
());
auto
not_p_as_y_type
=
make_shared
<
op
::
Convert
>
(
make_shared
<
op
::
Not
>
(
p
),
y
->
get_element_type
());
auto
not_p_as_y_type
=
make_shared
<
op
::
Convert
>
(
make_shared
<
op
::
Not
>
(
p
),
y
.
get_element_type
());
adjoints
.
add_delta
(
x
,
delta
*
p_as_x_type
);
adjoints
.
add_delta
(
x
,
delta
*
p_as_x_type
);
adjoints
.
add_delta
(
y
,
delta
*
not_p_as_y_type
);
adjoints
.
add_delta
(
y
,
delta
*
not_p_as_y_type
);
...
...
src/ngraph/op/sigmoid.cpp
View file @
6b5056e8
...
@@ -30,13 +30,13 @@ shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) con
...
@@ -30,13 +30,13 @@ shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) con
return
make_shared
<
Sigmoid
>
(
new_args
.
at
(
0
));
return
make_shared
<
Sigmoid
>
(
new_args
.
at
(
0
));
}
}
op
::
Sigmoid
::
Sigmoid
(
shared_ptr
<
Node
>
arg
)
op
::
Sigmoid
::
Sigmoid
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
op
::
SigmoidBackprop
::
SigmoidBackprop
(
shared_ptr
<
Node
>
arg
,
shared_ptr
<
Node
>
delta
)
op
::
SigmoidBackprop
::
SigmoidBackprop
(
const
Output
<
Node
>&
arg
,
const
Output
<
Node
>&
delta
)
:
BinaryElementwiseArithmetic
(
arg
,
delta
)
:
BinaryElementwiseArithmetic
(
arg
,
delta
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -52,6 +52,6 @@ void op::Sigmoid::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
...
@@ -52,6 +52,6 @@ void op::Sigmoid::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
{
{
auto
delta
=
deltas
.
at
(
0
);
auto
delta
=
deltas
.
at
(
0
);
auto
backprop
=
make_shared
<
op
::
SigmoidBackprop
>
(
get_argument
(
0
),
delta
);
auto
backprop
=
make_shared
<
op
::
SigmoidBackprop
>
(
input
(
0
).
get_source_output
(
),
delta
);
adjoints
.
add_delta
(
get_argument
(
0
),
backprop
);
adjoints
.
add_delta
(
input
(
0
).
get_source_output
(
),
backprop
);
}
}
src/ngraph/op/sigmoid.hpp
View file @
6b5056e8
...
@@ -31,7 +31,8 @@ namespace ngraph
...
@@ -31,7 +31,8 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
Sigmoid
(
std
::
shared_ptr
<
Node
>
arg
);
Sigmoid
(
const
Output
<
Node
>&
arg
);
Sigmoid
()
=
default
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
...
@@ -46,10 +47,11 @@ namespace ngraph
...
@@ -46,10 +47,11 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
SigmoidBackprop
()
=
default
;
/// \brief Constructs a SigmoidBackprop operation.
/// \brief Constructs a SigmoidBackprop operation.
///
///
/// \param arg Node that produces the Sigmoid forward input tensor.
/// \param arg Node that produces the Sigmoid forward input tensor.
SigmoidBackprop
(
std
::
shared_ptr
<
ngraph
::
Node
>
arg
,
std
::
shared_ptr
<
ngraph
::
Node
>
delta
);
SigmoidBackprop
(
const
Output
<
Node
>&
arg
,
const
Output
<
Node
>&
delta
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/sign.cpp
View file @
6b5056e8
...
@@ -21,7 +21,7 @@ using namespace ngraph;
...
@@ -21,7 +21,7 @@ using namespace ngraph;
const
string
op
::
Sign
::
type_name
{
"Sign"
};
const
string
op
::
Sign
::
type_name
{
"Sign"
};
op
::
Sign
::
Sign
(
const
shared_ptr
<
Node
>&
arg
)
op
::
Sign
::
Sign
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
...
src/ngraph/op/sign.hpp
View file @
6b5056e8
...
@@ -30,10 +30,11 @@ namespace ngraph
...
@@ -30,10 +30,11 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
Sign
()
=
default
;
/// \brief Constructs an elementwise sign operation.
/// \brief Constructs an elementwise sign operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Sign
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Sign
(
const
Output
<
Node
>&
arg
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/sin.cpp
View file @
6b5056e8
...
@@ -23,7 +23,7 @@ using namespace ngraph;
...
@@ -23,7 +23,7 @@ using namespace ngraph;
const
string
op
::
Sin
::
type_name
{
"Sin"
};
const
string
op
::
Sin
::
type_name
{
"Sin"
};
op
::
Sin
::
Sin
(
const
shared_ptr
<
Node
>&
arg
)
op
::
Sin
::
Sin
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -39,7 +39,7 @@ void op::Sin::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
...
@@ -39,7 +39,7 @@ void op::Sin::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{
{
auto
delta
=
deltas
.
at
(
0
);
auto
delta
=
deltas
.
at
(
0
);
auto
x
=
get_argument
(
0
);
auto
x
=
input
(
0
).
get_source_output
(
);
adjoints
.
add_delta
(
x
,
delta
*
(
make_shared
<
op
::
Cos
>
(
x
)));
adjoints
.
add_delta
(
x
,
delta
*
(
make_shared
<
op
::
Cos
>
(
x
)));
}
}
src/ngraph/op/sin.hpp
View file @
6b5056e8
...
@@ -44,7 +44,8 @@ namespace ngraph
...
@@ -44,7 +44,8 @@ namespace ngraph
/// \brief Constructs a sine operation.
/// \brief Constructs a sine operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Sin
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Sin
(
const
Output
<
Node
>&
arg
);
Sin
()
=
default
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/sinh.cpp
View file @
6b5056e8
...
@@ -23,7 +23,7 @@ using namespace ngraph;
...
@@ -23,7 +23,7 @@ using namespace ngraph;
const
string
op
::
Sinh
::
type_name
{
"Sinh"
};
const
string
op
::
Sinh
::
type_name
{
"Sinh"
};
op
::
Sinh
::
Sinh
(
const
shared_ptr
<
Node
>&
arg
)
op
::
Sinh
::
Sinh
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
...
src/ngraph/op/sinh.hpp
View file @
6b5056e8
...
@@ -32,7 +32,8 @@ namespace ngraph
...
@@ -32,7 +32,8 @@ namespace ngraph
/// \brief Constructs a hyperbolic sine operation.
/// \brief Constructs a hyperbolic sine operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Sinh
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Sinh
(
const
Output
<
Node
>&
arg
);
Sinh
()
=
default
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/slice.cpp
View file @
6b5056e8
...
@@ -21,10 +21,6 @@ using namespace ngraph;
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const
string
op
::
Slice
::
type_name
{
"Slice"
};
const
string
op
::
Slice
::
type_name
{
"Slice"
};
op
::
Slice
::
Slice
()
{
}
op
::
Slice
::
Slice
(
const
Output
<
Node
>&
arg
,
op
::
Slice
::
Slice
(
const
Output
<
Node
>&
arg
,
const
Coordinate
&
lower_bounds
,
const
Coordinate
&
lower_bounds
,
const
Coordinate
&
upper_bounds
,
const
Coordinate
&
upper_bounds
,
...
@@ -139,7 +135,7 @@ void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector
...
@@ -139,7 +135,7 @@ void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector
{
{
auto
delta
=
deltas
.
at
(
0
);
auto
delta
=
deltas
.
at
(
0
);
auto
x
=
get_argument
(
0
);
auto
x
=
input
(
0
).
get_source_output
(
);
adjoints
.
add_delta_to_slice
(
x
,
delta
,
m_lower_bounds
,
m_upper_bounds
,
m_strides
);
adjoints
.
add_delta_to_slice
(
x
,
delta
,
m_lower_bounds
,
m_upper_bounds
,
m_strides
);
}
}
src/ngraph/op/slice.hpp
View file @
6b5056e8
...
@@ -32,7 +32,7 @@ namespace ngraph
...
@@ -32,7 +32,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a tensor slice operation
/// \brief Constructs a tensor slice operation
Slice
();
Slice
()
=
default
;
/// \brief Constructs a tensor slice operation.
/// \brief Constructs a tensor slice operation.
///
///
/// \param arg The tensor to be sliced.
/// \param arg The tensor to be sliced.
...
...
src/ngraph/op/softmax.cpp
View file @
6b5056e8
...
@@ -31,7 +31,7 @@ using namespace ngraph;
...
@@ -31,7 +31,7 @@ using namespace ngraph;
const
string
op
::
Softmax
::
type_name
{
"Softmax"
};
const
string
op
::
Softmax
::
type_name
{
"Softmax"
};
op
::
Softmax
::
Softmax
(
const
shared_ptr
<
Node
>&
arg
,
const
AxisSet
&
axes
)
op
::
Softmax
::
Softmax
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
axes
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
,
m_axes
(
axes
)
,
m_axes
(
axes
)
{
{
...
@@ -88,6 +88,6 @@ void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
...
@@ -88,6 +88,6 @@ void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto
adjoint
=
z
-
builder
::
make_with_numpy_broadcast
<
op
::
Multiply
>
(
output
(
0
),
zreshape
);
auto
adjoint
=
z
-
builder
::
make_with_numpy_broadcast
<
op
::
Multiply
>
(
output
(
0
),
zreshape
);
auto
x
=
get_argument
(
0
);
auto
x
=
input
(
0
).
get_source_output
(
);
adjoints
.
add_delta
(
x
,
adjoint
);
adjoints
.
add_delta
(
x
,
adjoint
);
}
}
src/ngraph/op/softmax.hpp
View file @
6b5056e8
...
@@ -30,6 +30,7 @@ namespace ngraph
...
@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API
NGRAPH_API
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
Softmax
()
=
default
;
/// \brief Constructs a softmax operation.
/// \brief Constructs a softmax operation.
///
///
/// \param arg Node that produces the first input tensor.<br>
/// \param arg Node that produces the first input tensor.<br>
...
@@ -38,12 +39,13 @@ namespace ngraph
...
@@ -38,12 +39,13 @@ namespace ngraph
///
///
/// Output `[d0, ...]`
/// Output `[d0, ...]`
///
///
Softmax
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
AxisSet
&
axes
);
Softmax
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
axes
);
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
const
AxisSet
&
get_axes
()
const
{
return
m_axes
;
}
const
AxisSet
&
get_axes
()
const
{
return
m_axes
;
}
void
set_axes
(
const
AxisSet
&
axes
)
{
m_axes
=
axes
;
}
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
...
...
src/ngraph/op/sqrt.cpp
View file @
6b5056e8
...
@@ -23,7 +23,7 @@ using namespace ngraph;
...
@@ -23,7 +23,7 @@ using namespace ngraph;
const
string
op
::
Sqrt
::
type_name
{
"Sqrt"
};
const
string
op
::
Sqrt
::
type_name
{
"Sqrt"
};
op
::
Sqrt
::
Sqrt
(
const
shared_ptr
<
Node
>&
arg
)
op
::
Sqrt
::
Sqrt
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -39,7 +39,7 @@ void op::Sqrt::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
...
@@ -39,7 +39,7 @@ void op::Sqrt::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{
{
auto
delta
=
deltas
.
at
(
0
);
auto
delta
=
deltas
.
at
(
0
);
auto
x
=
get_argument
(
0
);
auto
x
=
input
(
0
).
get_source_output
(
);
adjoints
.
add_delta
(
x
,
delta
/
(
shared_from_this
()
+
shared_from_this
()));
adjoints
.
add_delta
(
x
,
delta
/
(
shared_from_this
()
+
shared_from_this
()));
}
}
src/ngraph/op/sqrt.hpp
View file @
6b5056e8
...
@@ -44,7 +44,8 @@ namespace ngraph
...
@@ -44,7 +44,8 @@ namespace ngraph
/// \brief Constructs a square operation.
/// \brief Constructs a square operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Sqrt
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Sqrt
(
const
Output
<
Node
>&
arg
);
Sqrt
()
=
default
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/stop_gradient.cpp
View file @
6b5056e8
...
@@ -23,7 +23,7 @@ using namespace ngraph;
...
@@ -23,7 +23,7 @@ using namespace ngraph;
const
string
op
::
StopGradient
::
type_name
{
"StopGradient"
};
const
string
op
::
StopGradient
::
type_name
{
"StopGradient"
};
op
::
StopGradient
::
StopGradient
(
const
shared_ptr
<
Node
>&
arg
)
op
::
StopGradient
::
StopGradient
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
...
src/ngraph/op/stop_gradient.hpp
View file @
6b5056e8
...
@@ -32,7 +32,8 @@ namespace ngraph
...
@@ -32,7 +32,8 @@ namespace ngraph
/// \brief Constructs StopGradient
/// \brief Constructs StopGradient
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
StopGradient
(
const
std
::
shared_ptr
<
Node
>&
arg
);
StopGradient
(
const
Output
<
Node
>&
arg
);
StopGradient
()
=
default
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/subtract.cpp
View file @
6b5056e8
...
@@ -45,8 +45,8 @@ void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
...
@@ -45,8 +45,8 @@ void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
auto
delta
=
deltas
.
at
(
0
);
auto
delta
=
deltas
.
at
(
0
);
auto
x
=
get_argument
(
0
);
auto
x
=
input
(
0
).
get_source_output
(
);
auto
y
=
get_argument
(
1
);
auto
y
=
input
(
1
).
get_source_output
(
);
adjoints
.
add_delta
(
x
,
delta
);
adjoints
.
add_delta
(
x
,
delta
);
adjoints
.
add_delta
(
y
,
-
delta
);
adjoints
.
add_delta
(
y
,
-
delta
);
...
...
src/ngraph/op/sum.cpp
View file @
6b5056e8
...
@@ -22,10 +22,6 @@ using namespace ngraph;
...
@@ -22,10 +22,6 @@ using namespace ngraph;
const
string
op
::
Sum
::
type_name
{
"Sum"
};
const
string
op
::
Sum
::
type_name
{
"Sum"
};
op
::
Sum
::
Sum
()
{
}
op
::
Sum
::
Sum
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reduction_axes
)
op
::
Sum
::
Sum
(
const
Output
<
Node
>&
arg
,
const
AxisSet
&
reduction_axes
)
:
ArithmeticReduction
(
arg
,
reduction_axes
)
:
ArithmeticReduction
(
arg
,
reduction_axes
)
{
{
...
...
src/ngraph/op/sum.hpp
View file @
6b5056e8
...
@@ -78,7 +78,7 @@ namespace ngraph
...
@@ -78,7 +78,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a summation operation.
/// \brief Constructs a summation operation.
Sum
();
Sum
()
=
default
;
/// \brief Constructs a summation operation.
/// \brief Constructs a summation operation.
///
///
/// \param arg The tensor to be summed.
/// \param arg The tensor to be summed.
...
...
src/ngraph/op/tan.cpp
View file @
6b5056e8
...
@@ -24,7 +24,7 @@ using namespace ngraph;
...
@@ -24,7 +24,7 @@ using namespace ngraph;
const
string
op
::
Tan
::
type_name
{
"Tan"
};
const
string
op
::
Tan
::
type_name
{
"Tan"
};
op
::
Tan
::
Tan
(
const
shared_ptr
<
Node
>&
arg
)
op
::
Tan
::
Tan
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -40,7 +40,7 @@ void op::Tan::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
...
@@ -40,7 +40,7 @@ void op::Tan::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{
{
auto
delta
=
deltas
.
at
(
0
);
auto
delta
=
deltas
.
at
(
0
);
auto
x
=
get_argument
(
0
);
auto
x
=
input
(
0
).
get_source_output
(
);
auto
c
=
make_shared
<
op
::
Cos
>
(
x
);
auto
c
=
make_shared
<
op
::
Cos
>
(
x
);
...
...
src/ngraph/op/tan.hpp
View file @
6b5056e8
...
@@ -44,7 +44,8 @@ namespace ngraph
...
@@ -44,7 +44,8 @@ namespace ngraph
/// \brief Constructs a tangent operation.
/// \brief Constructs a tangent operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Tan
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Tan
(
const
Output
<
Node
>&
arg
);
Tan
()
=
default
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/tanh.cpp
View file @
6b5056e8
...
@@ -23,7 +23,7 @@ using namespace ngraph;
...
@@ -23,7 +23,7 @@ using namespace ngraph;
const
string
op
::
Tanh
::
type_name
{
"Tanh"
};
const
string
op
::
Tanh
::
type_name
{
"Tanh"
};
op
::
Tanh
::
Tanh
(
const
shared_ptr
<
Node
>&
arg
)
op
::
Tanh
::
Tanh
(
const
Output
<
Node
>&
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
:
UnaryElementwiseArithmetic
(
arg
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
@@ -39,7 +39,7 @@ void op::Tanh::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
...
@@ -39,7 +39,7 @@ void op::Tanh::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{
{
auto
delta
=
deltas
.
at
(
0
);
auto
delta
=
deltas
.
at
(
0
);
auto
x
=
get_argument
(
0
);
auto
x
=
input
(
0
).
get_source_output
(
);
adjoints
.
add_delta
(
x
,
delta
-
(
delta
*
(
shared_from_this
()
*
shared_from_this
())));
adjoints
.
add_delta
(
x
,
delta
-
(
delta
*
(
shared_from_this
()
*
shared_from_this
())));
}
}
src/ngraph/op/tanh.hpp
View file @
6b5056e8
...
@@ -32,7 +32,8 @@ namespace ngraph
...
@@ -32,7 +32,8 @@ namespace ngraph
/// \brief Constructs a hyperbolic tangent operation.
/// \brief Constructs a hyperbolic tangent operation.
///
///
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Tanh
(
const
std
::
shared_ptr
<
Node
>&
arg
);
Tanh
(
const
Output
<
Node
>&
arg
);
Tanh
()
=
default
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/topk.cpp
View file @
6b5056e8
...
@@ -26,10 +26,6 @@ using namespace ngraph;
...
@@ -26,10 +26,6 @@ using namespace ngraph;
const
string
op
::
TopK
::
type_name
{
"TopK"
};
const
string
op
::
TopK
::
type_name
{
"TopK"
};
op
::
TopK
::
TopK
()
{
}
op
::
TopK
::
TopK
(
const
Output
<
Node
>&
arg
,
op
::
TopK
::
TopK
(
const
Output
<
Node
>&
arg
,
size_t
top_k_axis
,
size_t
top_k_axis
,
const
element
::
Type
&
index_element_type
,
const
element
::
Type
&
index_element_type
,
...
@@ -63,7 +59,8 @@ op::TopK::TopK(const Output<Node>& arg,
...
@@ -63,7 +59,8 @@ op::TopK::TopK(const Output<Node>& arg,
size_t
op
::
TopK
::
get_k
()
const
size_t
op
::
TopK
::
get_k
()
const
{
{
size_t
k
=
0
;
size_t
k
=
0
;
if
(
auto
const_op
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
1
)))
if
(
auto
const_op
=
dynamic_pointer_cast
<
op
::
Constant
>
(
input
(
1
).
get_source_output
().
get_node_shared_ptr
()))
{
{
k
=
const_op
->
get_vector
<
int64_t
>
()[
0
];
k
=
const_op
->
get_vector
<
int64_t
>
()[
0
];
}
}
...
...
src/ngraph/op/topk.hpp
View file @
6b5056e8
...
@@ -44,7 +44,7 @@ namespace ngraph
...
@@ -44,7 +44,7 @@ namespace ngraph
static
const
std
::
string
type_name
;
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a TopK operation
/// \brief Constructs a TopK operation
TopK
();
TopK
()
=
default
;
/// \brief Constructs a TopK operation.
/// \brief Constructs a TopK operation.
///
///
/// \param arg The input tensor
/// \param arg The input tensor
...
...
src/ngraph/serializer.cpp
View file @
6b5056e8
...
@@ -1576,7 +1576,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
...
@@ -1576,7 +1576,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node
=
make_shared
<
op
::
Passthrough
>
(
node_js
.
at
(
"logical_type"
),
node
=
make_shared
<
op
::
Passthrough
>
(
node_js
.
at
(
"logical_type"
),
node_js
.
at
(
"language"
),
node_js
.
at
(
"language"
),
node_js
.
at
(
"function"
),
node_js
.
at
(
"function"
),
args
,
static_cast
<
OutputVector
>
(
args
)
,
std
::
move
(
outputs
));
std
::
move
(
outputs
));
break
;
break
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment