Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
239de984
Commit
239de984
authored
Jun 25, 2019
by
Adam Rogowiec
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix errors after merge and add default values for optional inputs.
parent
61b935c2
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
139 additions
and
266 deletions
+139
-266
CMakeLists.txt
src/ngraph/CMakeLists.txt
+0
-2
gru_cell.cpp
src/ngraph/op/fused/gru_cell.cpp
+17
-18
gru_cell.hpp
src/ngraph/op/fused/gru_cell.hpp
+7
-4
rnn_cell.cpp
src/ngraph/op/fused/rnn_cell.cpp
+29
-17
rnn_cell.hpp
src/ngraph/op/fused/rnn_cell.hpp
+6
-3
rnn_cell_base.cpp
src/ngraph/op/fused/rnn_cell_base.cpp
+0
-98
rnn_cell_base.hpp
src/ngraph/op/fused/rnn_cell_base.hpp
+0
-118
rnn_cell_base.cpp
src/ngraph/op/util/rnn_cell_base.cpp
+37
-0
rnn_cell_base.hpp
src/ngraph/op/util/rnn_cell_base.hpp
+43
-6
No files found.
src/ngraph/CMakeLists.txt
View file @
239de984
...
...
@@ -326,8 +326,6 @@ set (SRC
op/fused/prelu.hpp
op/fused/rnn_cell.cpp
op/fused/rnn_cell.hpp
op/fused/rnn_cell_base.cpp
op/fused/rnn_cell_base.hpp
op/fused/scale_shift.cpp
op/fused/scale_shift.hpp
op/fused/shuffle_channels.cpp
...
...
src/ngraph/op/fused/gru_cell.cpp
View file @
239de984
...
...
@@ -19,11 +19,11 @@
#include <functional>
#include "ngraph/builder/split.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
...
...
@@ -65,6 +65,7 @@ op::GRUCell::GRUCell(const shared_ptr<Node>& X,
,
m_activation_g
{
get_activation_function
(
1
)}
,
m_linear_before_reset
{
linear_before_reset
}
{
add_default_bias_input
();
constructor_validate_and_infer_types
();
}
...
...
@@ -137,8 +138,6 @@ void op::GRUCell::pre_validate_and_infer_types()
w_shape
,
"."
);
if
(
get_input_size
()
==
5
)
{
const
auto
&
b_pshape
=
get_input_partial_shape
(
4
);
NODE_VALIDATION_CHECK
(
...
...
@@ -153,7 +152,6 @@ void op::GRUCell::pre_validate_and_infer_types()
"). Actual shape is:"
,
b_shape
,
"."
);
}
}
NodeVector
op
::
GRUCell
::
decompose_op
()
const
...
...
@@ -225,7 +223,7 @@ NodeVector op::GRUCell::decompose_op() const
const
auto
&
R_h
=
R_zr_h
.
at
(
1
);
// Xt*(W^T)
auto
Xt_W
=
make_shared
<
op
::
Dot
>
(
X
,
op
::
util
::
transpose
(
W
));
auto
Xt_W
=
make_shared
<
op
::
Dot
>
(
X
,
builder
::
transpose
(
W
));
// Split Xt_W into zr and h gates.
NodeVector
Xt_W_zr_h
=
builder
::
split
(
Xt_W
,
vector
<
size_t
>
{
2
*
get_hidden_size
(),
get_hidden_size
()},
1
);
...
...
@@ -235,7 +233,7 @@ NodeVector op::GRUCell::decompose_op() const
const
auto
&
Xt_W_h
=
Xt_W_zr_h
.
at
(
1
);
// Ht-1*(R^T) for update and reset gates. Tensor shape: [batch_size, 2 * hidden_size]
auto
Ht_R_zr
=
make_shared
<
op
::
Dot
>
(
H_t
,
op
::
util
::
transpose
(
R_zr
));
auto
Ht_R_zr
=
make_shared
<
op
::
Dot
>
(
H_t
,
builder
::
transpose
(
R_zr
));
// f(Xt*(W^T) + Ht-1*(R^T) + Wb + Rb) for update and reset gates.
// Tensor shape: [batch_size, 2 * hidden_size]
auto
zr_t
=
m_activation_f
(
clip
(
add
(
Xt_W_zr
,
add
(
Ht_R_zr
,
add
(
Wb_zr
,
Rb_zr
)))));
...
...
@@ -250,14 +248,14 @@ NodeVector op::GRUCell::decompose_op() const
if
(
m_linear_before_reset
)
{
// ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
Ht_Rh_Rb
=
add
(
make_shared
<
op
::
Dot
>
(
H_t
,
op
::
util
::
transpose
(
R_h
)),
Rb_h
);
auto
Ht_Rh_Rb
=
add
(
make_shared
<
op
::
Dot
>
(
H_t
,
builder
::
transpose
(
R_h
)),
Rb_h
);
h_t
=
m_activation_g
(
clip
(
add
(
Xt_W_h
,
add
(
mul
(
r_t
,
Ht_Rh_Rb
),
Wb_h
))));
}
else
{
// ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
rt_Ht
=
mul
(
r_t
,
H_t
);
auto
rt_Ht_Rh
=
make_shared
<
op
::
Dot
>
(
rt_Ht
,
op
::
util
::
transpose
(
R_h
));
auto
rt_Ht_Rh
=
make_shared
<
op
::
Dot
>
(
rt_Ht
,
builder
::
transpose
(
R_h
));
// Tensor shape: [batch_size, hidden_size]
h_t
=
m_activation_g
(
clip
(
add
(
Xt_W_h
,
add
(
rt_Ht_Rh
,
add
(
Rb_h
,
Wb_h
)))));
}
...
...
@@ -275,18 +273,19 @@ NodeVector op::GRUCell::decompose_op() const
shared_ptr
<
Node
>
op
::
GRUCell
::
get_bias
()
const
{
shared_ptr
<
Node
>
bias
;
if
(
get_input_size
()
==
5
)
{
bias
=
get_argument
(
4
);
}
else
{
// As default bias is all zeros, thus just initialize it with appropriate shape and zeros.
bias
=
op
::
Constant
::
create
(
input
(
0
).
get_element_type
(),
// Split B onto Wb an Rb and add them.
NodeVector
b_W_R
=
builder
::
split
(
get_argument
(
4
),
2
);
bias
=
b_W_R
.
at
(
0
)
+
b_W_R
.
at
(
1
);
return
bias
;
}
void
op
::
GRUCell
::
add_default_bias_input
()
{
shared_ptr
<
Node
>
B
=
op
::
Constant
::
create
(
input
(
0
).
get_element_type
(),
Shape
{
2
*
m_gates_count
*
get_hidden_size
()},
vector
<
float
>
(
2
*
m_gates_count
*
get_hidden_size
(),
0.
f
));
}
return
bias
;
set_argument
(
4
,
B
->
output
(
0
));
}
shared_ptr
<
Node
>
op
::
GRUCell
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
...
...
src/ngraph/op/fused/gru_cell.hpp
View file @
239de984
...
...
@@ -22,9 +22,9 @@
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/fused/rnn_cell_base.hpp"
#include "ngraph/op/util/activation_functions.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
namespace
ngraph
{
...
...
@@ -38,7 +38,7 @@ namespace ngraph
///
/// Note this class represents only single *cell* and not whole GRU *layer*.
///
class
GRUCell
:
public
util
::
FusedOp
,
public
RNNCellBase
class
GRUCell
:
public
util
::
FusedOp
,
public
util
::
RNNCellBase
{
public
:
///
...
...
@@ -134,14 +134,17 @@ namespace ngraph
private
:
std
::
shared_ptr
<
Node
>
get_bias
()
const
;
/// brief Add and initialize bias input to all zeros.
void
add_default_bias_input
();
///
/// \brief The Activation function f.
///
ActivationFunction
m_activation_f
;
util
::
ActivationFunction
m_activation_f
;
///
/// \brief The Activation function g.
///
ActivationFunction
m_activation_g
;
util
::
ActivationFunction
m_activation_g
;
static
constexpr
std
::
size_t
m_gates_count
{
3
};
///
...
...
src/ngraph/op/fused/rnn_cell.cpp
View file @
239de984
...
...
@@ -19,11 +19,11 @@
#include <functional>
#include "ngraph/builder/split.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
...
...
@@ -54,6 +54,7 @@ op::RNNCell::RNNCell(const shared_ptr<Node>& X,
,
RNNCellBase
(
hidden_size
,
clip
,
activations
,
activation_alpha
,
activation_beta
)
,
m_activation_f
{
get_activation_function
(
0
)}
{
add_default_bias_input
();
constructor_validate_and_infer_types
();
}
...
...
@@ -123,8 +124,6 @@ void op::RNNCell::pre_validate_and_infer_types()
w_shape
,
"."
);
if
(
get_input_size
()
==
5
)
{
const
auto
&
b_pshape
=
get_input_partial_shape
(
4
);
NODE_VALIDATION_CHECK
(
...
...
@@ -139,7 +138,6 @@ void op::RNNCell::pre_validate_and_infer_types()
"). Actual shape is:"
,
b_shape
,
"."
);
}
}
NodeVector
op
::
RNNCell
::
decompose_op
()
const
...
...
@@ -179,9 +177,9 @@ NodeVector op::RNNCell::decompose_op() const
auto
bias
=
b_W_R
.
at
(
0
)
+
b_W_R
.
at
(
1
);
// Xt*(W^T)
auto
Xt_W
=
std
::
make_shared
<
ngraph
::
op
::
Dot
>
(
X
,
ngraph
::
op
::
util
::
transpose
(
W
));
auto
Xt_W
=
std
::
make_shared
<
op
::
Dot
>
(
X
,
builder
::
transpose
(
W
));
// Ht-1*(R^T)
auto
Ht_R
=
std
::
make_shared
<
ngraph
::
op
::
Dot
>
(
H_t
,
ngraph
::
op
::
util
::
transpose
(
R
));
auto
Ht_R
=
std
::
make_shared
<
op
::
Dot
>
(
H_t
,
builder
::
transpose
(
R
));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
auto
i_t
=
add
(
Xt_W
,
add
(
Ht_R
,
bias
));
...
...
@@ -193,21 +191,35 @@ NodeVector op::RNNCell::decompose_op() const
shared_ptr
<
Node
>
op
::
RNNCell
::
get_bias
()
const
{
// shared_ptr<Node> bias;
// if (get_input_size() == 5)
// {
// bias = get_argument(4);
// }
// else
// {
// // As default bias is all zeros, thus just initialize it with appropriate shape and zeros.
// bias = op::Constant::create(input(0).get_element_type(),
// Shape{2 * get_hidden_size()},
// vector<float>(2 * get_hidden_size(), 0.f));
// }
// return bias;
shared_ptr
<
Node
>
bias
;
if
(
get_input_size
()
==
5
)
{
bias
=
get_argument
(
4
);
}
else
{
// As default bias is all zeros, thus just initialize it with appropriate shape and zeros.
bias
=
op
::
Constant
::
create
(
input
(
0
).
get_element_type
(),
Shape
{
2
*
get_hidden_size
()},
vector
<
float
>
(
2
*
get_hidden_size
(),
0.
f
));
}
// Split B onto Wb an Rb and add them.
NodeVector
b_W_R
=
builder
::
split
(
get_argument
(
4
),
2
);
bias
=
b_W_R
.
at
(
0
)
+
b_W_R
.
at
(
1
);
return
bias
;
}
void
op
::
RNNCell
::
add_default_bias_input
()
{
shared_ptr
<
Node
>
B
=
op
::
Constant
::
create
(
input
(
0
).
get_element_type
(),
Shape
{
2
*
m_gates_count
*
get_hidden_size
()},
vector
<
float
>
(
2
*
m_gates_count
*
get_hidden_size
(),
0.
f
));
set_argument
(
4
,
B
->
output
(
0
));
}
shared_ptr
<
Node
>
op
::
RNNCell
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/fused/rnn_cell.hpp
View file @
239de984
...
...
@@ -22,9 +22,9 @@
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/fused/rnn_cell_base.hpp"
#include "ngraph/op/util/activation_functions.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
namespace
ngraph
{
...
...
@@ -38,7 +38,7 @@ namespace ngraph
///
/// Note this class represents only single *cell* and not whole RNN *layer*.
///
class
RNNCell
:
public
util
::
FusedOp
,
public
RNNCellBase
class
RNNCell
:
public
util
::
FusedOp
,
public
util
::
RNNCellBase
{
public
:
///
...
...
@@ -126,10 +126,13 @@ namespace ngraph
private
:
std
::
shared_ptr
<
Node
>
get_bias
()
const
;
/// brief Add and initialize bias input to all zeros.
void
add_default_bias_input
();
///
/// \brief The Activation function f.
///
ActivationFunction
m_activation_f
;
util
::
ActivationFunction
m_activation_f
;
static
constexpr
std
::
size_t
m_gates_count
{
1
};
};
...
...
src/ngraph/op/fused/rnn_cell_base.cpp
deleted
100644 → 0
View file @
61b935c2
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iterator>
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/rnn_cell_base.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
// Modify input vector in-place and return reference to modified vector.
static
vector
<
string
>
to_lower_case
(
const
vector
<
string
>&
vs
)
{
vector
<
string
>
res
(
vs
);
transform
(
begin
(
res
),
end
(
res
),
begin
(
res
),
[](
string
&
s
)
{
return
to_lower
(
s
);
});
return
res
;
}
op
::
RNNCellBase
::
RNNCellBase
(
size_t
hidden_size
,
float
clip
,
const
vector
<
string
>&
activations
,
const
vector
<
float
>&
activation_alpha
,
const
vector
<
float
>&
activation_beta
)
:
m_hidden_size
(
hidden_size
)
,
m_clip
(
clip
)
,
m_activations
(
to_lower_case
(
activations
))
,
m_activation_alpha
(
activation_alpha
)
,
m_activation_beta
(
activation_beta
)
{
}
op
::
ActivationFunction
op
::
RNNCellBase
::
get_activation_function
(
size_t
idx
)
const
{
op
::
ActivationFunction
afunc
=
get_activation_func_by_name
(
m_activations
.
at
(
idx
));
// Set activation functions parameters (if any)
if
(
m_activation_alpha
.
size
()
>
idx
)
{
afunc
.
set_alpha
(
m_activation_alpha
.
at
(
idx
));
}
if
(
m_activation_beta
.
size
()
>
idx
)
{
afunc
.
set_beta
(
m_activation_beta
.
at
(
idx
));
}
return
afunc
;
}
shared_ptr
<
Node
>
op
::
RNNCellBase
::
add
(
const
shared_ptr
<
Node
>&
lhs
,
const
shared_ptr
<
Node
>&
rhs
)
{
auto
args
=
op
::
numpy_style_broadcast
({
lhs
,
rhs
});
return
{
make_shared
<
op
::
Add
>
(
args
.
at
(
0
),
args
.
at
(
1
))};
}
shared_ptr
<
Node
>
op
::
RNNCellBase
::
sub
(
const
shared_ptr
<
Node
>&
lhs
,
const
shared_ptr
<
Node
>&
rhs
)
{
auto
args
=
op
::
numpy_style_broadcast
({
lhs
,
rhs
});
return
{
make_shared
<
op
::
Subtract
>
(
args
.
at
(
0
),
args
.
at
(
1
))};
}
shared_ptr
<
Node
>
op
::
RNNCellBase
::
mul
(
const
shared_ptr
<
Node
>&
lhs
,
const
shared_ptr
<
Node
>&
rhs
)
{
auto
args
=
op
::
numpy_style_broadcast
({
lhs
,
rhs
});
return
{
make_shared
<
op
::
Multiply
>
(
args
.
at
(
0
),
args
.
at
(
1
))};
}
shared_ptr
<
Node
>
op
::
RNNCellBase
::
clip
(
const
shared_ptr
<
Node
>&
data
)
const
{
if
(
m_clip
==
0.
f
)
{
return
data
;
}
return
make_shared
<
op
::
Clamp
>
(
data
,
-
m_clip
,
m_clip
);
}
src/ngraph/op/fused/rnn_cell_base.hpp
deleted
100644 → 0
View file @
61b935c2
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp"
namespace
ngraph
{
namespace
op
{
/// \brief Base class for all recurrent network cells.
///
/// \note It holds all common attributes.
///
class
RNNCellBase
{
public
:
///
/// \brief Constructs a RNNCellBase class.
///
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation
/// functions in order respective to activation list.
///
RNNCellBase
(
std
::
size_t
hidden_size
,
float
clip
,
const
std
::
vector
<
std
::
string
>&
activations
,
const
std
::
vector
<
float
>&
activation_alpha
,
const
std
::
vector
<
float
>&
activation_beta
);
std
::
size_t
get_hidden_size
()
const
{
return
m_hidden_size
;
}
float
get_clip
()
const
{
return
m_clip
;
}
const
std
::
vector
<
std
::
string
>&
get_activations
()
const
{
return
m_activations
;
}
const
std
::
vector
<
float
>&
get_activation_alpha
()
const
{
return
m_activation_alpha
;
}
const
std
::
vector
<
float
>&
get_activation_beta
()
const
{
return
m_activation_beta
;
}
protected
:
///
/// \brief Constructs activation function object.
///
/// \param[in] idx The index of the activation function name.
///
/// \return The object representing activation function.
///
ActivationFunction
get_activation_function
(
std
::
size_t
idx
)
const
;
///
/// \brief Creates node with element-wise add operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise add operation.
///
static
std
::
shared_ptr
<
Node
>
add
(
const
std
::
shared_ptr
<
Node
>&
lhs
,
const
std
::
shared_ptr
<
Node
>&
rhs
);
///
/// \brief Creates node with element-wise subtract operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise subtract operation.
///
static
std
::
shared_ptr
<
Node
>
sub
(
const
std
::
shared_ptr
<
Node
>&
lhs
,
const
std
::
shared_ptr
<
Node
>&
rhs
);
///
/// \brief Creates node with element-wise multiply operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise multiply operation.
///
static
std
::
shared_ptr
<
Node
>
mul
(
const
std
::
shared_ptr
<
Node
>&
lhs
,
const
std
::
shared_ptr
<
Node
>&
rhs
);
///
/// \brief Creates node with element-wise clip operation with numpy broadcasting.
///
/// \param[in] data The input tensor for clipping.
///
/// \return Node with element-wise clip operation.
///
std
::
shared_ptr
<
Node
>
clip
(
const
std
::
shared_ptr
<
Node
>&
data
)
const
;
private
:
const
std
::
size_t
m_hidden_size
=
0.
f
;
const
float
m_clip
=
0.
f
;
const
std
::
vector
<
std
::
string
>
m_activations
;
const
std
::
vector
<
float
>
m_activation_alpha
;
const
std
::
vector
<
float
>
m_activation_beta
;
};
}
}
src/ngraph/op/util/rnn_cell_base.cpp
View file @
239de984
...
...
@@ -17,6 +17,14 @@
#include <algorithm>
#include <iterator>
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
#include "ngraph/util.hpp"
...
...
@@ -60,3 +68,32 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size
return
afunc
;
}
shared_ptr
<
Node
>
op
::
util
::
RNNCellBase
::
add
(
const
shared_ptr
<
Node
>&
lhs
,
const
shared_ptr
<
Node
>&
rhs
)
{
auto
args
=
op
::
numpy_style_broadcast
({
lhs
,
rhs
});
return
{
make_shared
<
op
::
Add
>
(
args
.
at
(
0
),
args
.
at
(
1
))};
}
shared_ptr
<
Node
>
op
::
util
::
RNNCellBase
::
sub
(
const
shared_ptr
<
Node
>&
lhs
,
const
shared_ptr
<
Node
>&
rhs
)
{
auto
args
=
op
::
numpy_style_broadcast
({
lhs
,
rhs
});
return
{
make_shared
<
op
::
Subtract
>
(
args
.
at
(
0
),
args
.
at
(
1
))};
}
shared_ptr
<
Node
>
op
::
util
::
RNNCellBase
::
mul
(
const
shared_ptr
<
Node
>&
lhs
,
const
shared_ptr
<
Node
>&
rhs
)
{
auto
args
=
op
::
numpy_style_broadcast
({
lhs
,
rhs
});
return
{
make_shared
<
op
::
Multiply
>
(
args
.
at
(
0
),
args
.
at
(
1
))};
}
shared_ptr
<
Node
>
op
::
util
::
RNNCellBase
::
clip
(
const
shared_ptr
<
Node
>&
data
)
const
{
if
(
m_clip
==
0.
f
)
{
return
data
;
}
return
make_shared
<
op
::
Clamp
>
(
data
,
-
m_clip
,
m_clip
);
}
src/ngraph/op/util/rnn_cell_base.hpp
View file @
239de984
...
...
@@ -17,9 +17,11 @@
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp"
namespace
ngraph
...
...
@@ -57,10 +59,7 @@ namespace ngraph
std
::
size_t
get_hidden_size
()
const
{
return
m_hidden_size
;
}
float
get_clip
()
const
{
return
m_clip
;
}
const
std
::
vector
<
std
::
string
>&
get_activations
()
const
{
return
m_activations
;
}
const
std
::
vector
<
float
>&
get_activation_alpha
()
const
{
return
m_activation_alpha
;
}
const
std
::
vector
<
float
>&
get_activation_alpha
()
const
{
return
m_activation_alpha
;
}
const
std
::
vector
<
float
>&
get_activation_beta
()
const
{
return
m_activation_beta
;
}
protected
:
///
...
...
@@ -71,10 +70,48 @@ namespace ngraph
/// \return The object representing activation function.
///
ActivationFunction
get_activation_function
(
std
::
size_t
idx
)
const
;
///
/// \brief Creates node with element-wise add operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise add operation.
///
static
std
::
shared_ptr
<
Node
>
add
(
const
std
::
shared_ptr
<
Node
>&
lhs
,
const
std
::
shared_ptr
<
Node
>&
rhs
);
///
/// \brief Creates node with element-wise subtract operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise subtract operation.
///
static
std
::
shared_ptr
<
Node
>
sub
(
const
std
::
shared_ptr
<
Node
>&
lhs
,
const
std
::
shared_ptr
<
Node
>&
rhs
);
///
/// \brief Creates node with element-wise multiply operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise multiply operation.
///
static
std
::
shared_ptr
<
Node
>
mul
(
const
std
::
shared_ptr
<
Node
>&
lhs
,
const
std
::
shared_ptr
<
Node
>&
rhs
);
///
/// \brief Creates node with element-wise clip operation with numpy broadcasting.
///
/// \param[in] data The input tensor for clipping.
///
/// \return Node with element-wise clip operation.
///
std
::
shared_ptr
<
Node
>
clip
(
const
std
::
shared_ptr
<
Node
>&
data
)
const
;
private
:
std
::
size_t
m_hidden_size
=
0.
f
;
float
m_clip
=
0.
f
;
const
std
::
size_t
m_hidden_size
=
0.
f
;
const
float
m_clip
=
0.
f
;
const
std
::
vector
<
std
::
string
>
m_activations
;
const
std
::
vector
<
float
>
m_activation_alpha
;
const
std
::
vector
<
float
>
m_activation_beta
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment