Unverified Commit 6ea70ac6 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by GitHub

Merge branch 'master' into arogowiec/lstm_fix

parents 7ce7e037 e687fe7e
......@@ -318,6 +318,8 @@ set (SRC
op/fused/group_conv.cpp
op/fused/group_conv_transpose.hpp
op/fused/group_conv_transpose.cpp
op/fused/gru_cell.cpp
op/fused/gru_cell.hpp
op/fused/leaky_relu.cpp
op/fused/leaky_relu.hpp
op/fused/lstm_cell.cpp
......@@ -328,6 +330,8 @@ set (SRC
op/fused/normalize.hpp
op/fused/prelu.cpp
op/fused/prelu.hpp
op/fused/rnn_cell.cpp
op/fused/rnn_cell.hpp
op/fused/scale_shift.cpp
op/fused/scale_shift.hpp
op/fused/shuffle_channels.cpp
......
......@@ -25,7 +25,6 @@
#include "exceptions.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
......
......@@ -105,12 +105,14 @@
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
......
......@@ -99,7 +99,8 @@ void op::DynReshape::validate_and_infer_types()
if (out_shape_val[i] == 0 && m_zero_flag)
{
// Copy input_shape[i] for zero values
NGRAPH_CHECK(i < input_shape.size());
NODE_VALIDATION_CHECK(
this, i < input_shape.size(), "'0' dimension is out of range");
partial_shape[i] = Dimension(input_shape[i]);
output_elements *= input_shape[i];
}
......@@ -119,12 +120,21 @@ void op::DynReshape::validate_and_infer_types()
// input elements
if (output_elements == 0)
{
NGRAPH_CHECK(input_elements == 0);
// TODO(amprocte): Decide if this is desired behavior here. (NumPy seems
// to fail.)
NODE_VALIDATION_CHECK(this,
input_elements == 0,
"Cannot infer '-1' dimension with zero-size output "
"dimension unless at least one input dimension is "
"also zero-size");
partial_shape[negative_dim] = Dimension(0);
}
else
{
NGRAPH_CHECK(input_elements % output_elements == 0);
NODE_VALIDATION_CHECK(
this,
input_elements % output_elements == 0,
"Non-'-1' output dimensions do not evenly divide the input dimensions");
partial_shape[negative_dim] = Dimension(input_elements / output_elements);
}
}
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class for GRU cell node.
///
/// \note It follows notation and equations defined as in ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU
///
/// Note this class represents only single *cell* and not whole GRU *layer*.
///
class GRUCell : public util::FusedOp, public util::RNNCellBase
{
public:
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
///
GRUCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size);
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
///
GRUCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta,
float clip,
bool linear_before_reset);
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] B The bias tensor for input gate with shape:
/// [2 * gates_count * hidden_size].
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
///
GRUCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::shared_ptr<Node>& B,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
float clip = 0.f,
bool linear_before_reset = false);
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_linear_before_reset() const { return m_linear_before_reset; }
private:
/// brief Add and initialize bias input to all zeros.
void add_default_bias_input();
///
/// \brief The Activation function f.
///
util::ActivationFunction m_activation_f;
///
/// \brief The Activation function g.
///
util::ActivationFunction m_activation_g;
static constexpr std::size_t s_gates_count{3};
///
/// \brief Control whether or not apply the linear transformation.
///
/// \note The linear transformation may be applied when computing the output of hidden gate.
/// It's done before multiplying by the output of the reset gate.
///
bool m_linear_before_reset;
};
}
}
......@@ -24,11 +24,6 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
......@@ -36,46 +31,6 @@
using namespace std;
using namespace ngraph;
// ------------- HELPER FUNCTIONS ---------------------------------------------
static shared_ptr<Node> add(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Add>(args.at(0), args.at(1))};
}
static shared_ptr<Node> sub(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Subtract>(args.at(0), args.at(1))};
}
static shared_ptr<Node> mul(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Multiply>(args.at(0), args.at(1))};
}
static shared_ptr<Node> clip(const shared_ptr<Node>& data, float threshold)
{
if (threshold == 0.f)
{
return data;
}
float min_val = -threshold;
float max_val = threshold;
size_t size = shape_size(data->get_shape());
const shared_ptr<Node> min_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, min_val));
const shared_ptr<Node> max_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, max_val));
return make_shared<op::Minimum>(max_val_node, make_shared<op::Maximum>(data, min_val_node));
}
// ------------- LSTM_CELL ----------------------------------------------------
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
......@@ -166,18 +121,18 @@ void op::LSTMCell::pre_validate_and_infer_types()
const Shape& ct_shape{ct_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(w_shape == Shape{m_gates_count * get_hidden_size(), input_size}),
(w_shape == Shape{s_gates_count * get_hidden_size(), input_size}),
"Input tensor W must have shape (",
m_gates_count * get_hidden_size(),
s_gates_count * get_hidden_size(),
", ",
input_size,
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(r_shape == Shape{m_gates_count * get_hidden_size(), get_hidden_size()}),
(r_shape == Shape{s_gates_count * get_hidden_size(), get_hidden_size()}),
"Input tensor R must have shape (",
m_gates_count * get_hidden_size(),
s_gates_count * get_hidden_size(),
", ",
get_hidden_size(),
"). Actual shape is:",
......@@ -213,7 +168,7 @@ void op::LSTMCell::pre_validate_and_infer_types()
const Shape& p_shape{p_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(b_shape == Shape{2 * m_gates_count * get_hidden_size()}),
(b_shape == Shape{2 * s_gates_count * get_hidden_size()}),
"Input tensor B must have shape (",
8 * get_hidden_size(),
"). Actual shape is:",
......@@ -221,9 +176,9 @@ void op::LSTMCell::pre_validate_and_infer_types()
".");
NODE_VALIDATION_CHECK(this,
(p_shape == Shape{m_peepholes_count * get_hidden_size()}),
(p_shape == Shape{s_peepholes_count * get_hidden_size()}),
"Input tensor P must have shape (",
m_peepholes_count * get_hidden_size(),
s_peepholes_count * get_hidden_size(),
"). Actual shape is:",
p_shape,
".");
......@@ -295,7 +250,7 @@ NodeVector op::LSTMCell::decompose_op() const
auto c_t = split_gates.at(3);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
i_t = m_activation_f(clip(add(i_t, mul(p_i, C_t)), get_clip()));
i_t = m_activation_f(clip(add(i_t, mul(p_i, C_t))));
if (m_input_forget)
{
// Couple input with forget gate: 1 - i_t
......@@ -307,14 +262,14 @@ NodeVector op::LSTMCell::decompose_op() const
else
{
// f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
f_t = m_activation_f(clip(add(f_t, mul(p_f, C_t)), get_clip()));
f_t = m_activation_f(clip(add(f_t, mul(p_f, C_t))));
}
// ft (.) Ct-1 + it (.) ct
auto C = add(mul(f_t, C_t), mul(i_t, m_activation_g(clip(c_t, get_clip()))));
auto C = add(mul(f_t, C_t), mul(i_t, m_activation_g(clip(c_t))));
// f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
o_t = m_activation_f(clip(add(o_t, mul(p_o, C)), get_clip()));
o_t = m_activation_f(clip(add(o_t, mul(p_o, C))));
// ot (.) h(Ct)
auto H = mul(o_t, m_activation_h(clip(C, get_clip())));
auto H = mul(o_t, m_activation_h(clip(C)));
return {H, C};
}
......@@ -332,15 +287,15 @@ NodeVector op::LSTMCell::get_peephole_weigths() const
{
shared_ptr<Node> P;
P = get_argument(6);
return builder::split(P, m_peepholes_count);
return builder::split(P, s_peepholes_count);
}
void op::LSTMCell::add_default_bias_input()
{
shared_ptr<Node> B =
op::Constant::create(input(0).get_element_type(),
Shape{2 * m_gates_count * get_hidden_size()},
vector<float>(2 * m_gates_count * get_hidden_size(), 0.f));
Shape{2 * s_gates_count * get_hidden_size()},
vector<float>(2 * s_gates_count * get_hidden_size(), 0.f));
set_argument(5, B->output(0));
}
......@@ -348,8 +303,8 @@ void op::LSTMCell::add_default_peepholes_input()
{
shared_ptr<Node> P =
op::Constant::create(input(0).get_element_type(),
Shape{m_peepholes_count * get_hidden_size()},
vector<float>(m_peepholes_count * get_hidden_size(), 0.f));
Shape{s_peepholes_count * get_hidden_size()},
vector<float>(s_peepholes_count * get_hidden_size(), 0.f));
set_argument(6, P->output(0));
}
......
......@@ -168,8 +168,8 @@ namespace ngraph
///
bool m_input_forget = false;
static constexpr std::size_t m_gates_count{4};
static constexpr std::size_t m_peepholes_count{3};
static constexpr std::size_t s_gates_count{4};
static constexpr std::size_t s_peepholes_count{3};
};
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <functional>
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::RNNCell::RNNCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
size_t hidden_size)
: RNNCell(
X, W, R, H_t, hidden_size, vector<string>{"tanh"}, vector<float>{}, vector<float>{}, 0.f)
{
}
op::RNNCell::RNNCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
float clip)
: FusedOp("RNNCell", {X, W, R, H_t})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)}
{
add_default_bias_input();
constructor_validate_and_infer_types();
}
op::RNNCell::RNNCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
size_t hidden_size,
const shared_ptr<Node>& B,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
float clip)
: FusedOp("RNNCell", {X, W, R, H_t, B})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)}
{
constructor_validate_and_infer_types();
}
void op::RNNCell::pre_validate_and_infer_types()
{
const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2);
const auto& ht_pshape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
(x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() ||
ht_pshape.is_static()),
"RNNCell supports only static input tensors.");
const Shape& x_shape{x_pshape.to_shape()};
const size_t batch_size = x_shape.at(0);
const size_t input_size = x_shape.at(1);
const Shape& w_shape{w_pshape.to_shape()};
const Shape& r_shape{r_pshape.to_shape()};
const Shape& ht_shape{ht_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(w_shape == Shape{get_hidden_size(), input_size}),
"Input tensor W must have shape (",
get_hidden_size(),
", ",
input_size,
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(r_shape == Shape{get_hidden_size(), get_hidden_size()}),
"Input tensor R must have shape (",
get_hidden_size(),
", ",
get_hidden_size(),
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(ht_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor H_t must have shape (",
batch_size,
", ",
get_hidden_size(),
"). Actual shape is:",
w_shape,
".");
const auto& b_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(
this, b_pshape.is_static(), "RNNCell supports only static input tensors.");
const Shape& b_shape{b_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(b_shape == Shape{2 * get_hidden_size()}),
"Input tensor B must have shape (",
2 * get_hidden_size(),
"). Actual shape is:",
b_shape,
".");
}
NodeVector op::RNNCell::decompose_op() const
{
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the one used in ONNX documentation.
//
// ------ ACRONYMS ------
// i_t - input gate at current time step
// t - time step (t-1 means previous time step)
// X - The input data tensor. Shape: [batch_size, input_size].
// W - The weight tensor for input gate. Shape: [hidden_size, input_size].
// R - The recurrence weight tensor for input gate. Shape: [hidden_size, hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
// B - The bias tensor for the input gate. Shape: [2 * hidden_size] Concatenation of `[Wb, Rb]`.
// Wb - W bias vectors for input gate.
// Rb - R bias vectors for input gate.
// ------ VARIABLE NAMES ------
// Xt_W - Input sequence multiplied by weights tensor at current time step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
// ---- Equations ----
// f - is activation functions.
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// --------------------
std::shared_ptr<Node> X = get_argument(0);
std::shared_ptr<Node> W = get_argument(1);
std::shared_ptr<Node> R = get_argument(2);
std::shared_ptr<Node> H_t = get_argument(3);
std::shared_ptr<Node> bias = get_bias();
// Xt*(W^T)
auto Xt_W = std::make_shared<op::Dot>(X, builder::transpose(W));
// Ht-1*(R^T)
auto Ht_R = std::make_shared<op::Dot>(H_t, builder::transpose(R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
auto i_t = add(Xt_W, add(Ht_R, bias));
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
i_t = m_activation_f(clip(i_t));
return {i_t};
}
shared_ptr<Node> op::RNNCell::get_bias() const
{
shared_ptr<Node> bias;
// Split B onto Wb an Rb and add them.
NodeVector b_W_R = builder::split(get_argument(4), 2);
bias = b_W_R.at(0) + b_W_R.at(1);
return bias;
}
void op::RNNCell::add_default_bias_input()
{
shared_ptr<Node> B =
op::Constant::create(input(0).get_element_type(),
Shape{2 * s_gates_count * get_hidden_size()},
vector<float>(2 * s_gates_count * get_hidden_size(), 0.f));
set_argument(4, B->output(0));
}
shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
if (new_args.size() == 4)
{
return make_shared<RNNCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
get_hidden_size(),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_clip());
}
else if (new_args.size() == 5)
{
return make_shared<RNNCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
get_hidden_size(),
new_args.at(4),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_clip());
}
else
{
throw ngraph_error("Incorrect number of new arguments");
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class for RNN cell node.
///
/// \note It follows notation and equations defined as in ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN
///
/// Note this class represents only single *cell* and not whole RNN *layer*.
///
class RNNCell : public util::FusedOp, public util::RNNCellBase
{
public:
///
/// \brief Constructs RNNCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
///
RNNCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size);
///
/// \brief Constructs RNNCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
///
RNNCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta,
float clip);
///
/// \brief Constructs RNNCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] B The bias tensor for input gate with shape: [2*hidden_size].
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
///
RNNCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::shared_ptr<Node>& B,
const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
float clip = 0.f);
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
std::shared_ptr<Node> get_bias() const;
/// brief Add and initialize bias input to all zeros.
void add_default_bias_input();
///
/// \brief The Activation function f.
///
util::ActivationFunction m_activation_f;
static constexpr std::size_t s_gates_count{1};
};
}
}
......@@ -28,12 +28,14 @@ NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LeakyRelu, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Normalize, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
......
......@@ -17,6 +17,14 @@
#include <algorithm>
#include <iterator>
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
#include "ngraph/util.hpp"
......@@ -60,3 +68,34 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size
return afunc;
}
shared_ptr<Node> op::util::RNNCellBase::add(const shared_ptr<Node>& lhs,
const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Add>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::util::RNNCellBase::sub(const shared_ptr<Node>& lhs,
const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Subtract>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::util::RNNCellBase::mul(const shared_ptr<Node>& lhs,
const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Multiply>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::util::RNNCellBase::clip(const shared_ptr<Node>& data) const
{
if (m_clip == 0.f)
{
return data;
}
return make_shared<op::Clamp>(data, -m_clip, m_clip);
}
......@@ -17,9 +17,11 @@
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp"
namespace ngraph
......@@ -71,10 +73,48 @@ namespace ngraph
/// \return The object representing activation function.
///
ActivationFunction get_activation_function(std::size_t idx) const;
///
/// \brief Creates node with element-wise add operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise add operation.
///
static std::shared_ptr<Node> add(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
///
/// \brief Creates node with element-wise subtract operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise subtract operation.
///
static std::shared_ptr<Node> sub(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
///
/// \brief Creates node with element-wise multiply operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise multiply operation.
///
static std::shared_ptr<Node> mul(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
///
/// \brief Creates node with element-wise clip operation with numpy broadcasting.
///
/// \param[in] data The input tensor for clipping.
///
/// \return Node with element-wise clip operation.
///
std::shared_ptr<Node> clip(const std::shared_ptr<Node>& data) const;
private:
std::size_t m_hidden_size = 0.f;
float m_clip = 0.f;
const std::size_t m_hidden_size;
const float m_clip;
const std::vector<std::string> m_activations;
const std::vector<float> m_activation_alpha;
const std::vector<float> m_activation_beta;
......
......@@ -14,9 +14,12 @@
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include "dyn_elimination.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/transpose.hpp"
......@@ -34,6 +37,7 @@ pass::DynElimination::DynElimination()
{
construct_transpose();
construct_broadcast();
construct_dyn_slice();
construct_dyn_reshape();
construct_range();
}
......@@ -367,7 +371,7 @@ static SlicePlan make_plan(const Shape& input_shape,
return p;
}
void pass::DynElimination::construct_dyn_reshape()
void pass::DynElimination::construct_dyn_slice()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto begins_arg_label =
......@@ -436,10 +440,53 @@ void pass::DynElimination::construct_dyn_reshape()
};
auto dyn_slice_matcher =
make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynShape");
make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynSlice");
add_matcher(dyn_slice_matcher, dyn_slice_callback, all_pass_property_off);
}
void pass::DynElimination::construct_dyn_reshape()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto shape_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto dyn_reshape = make_shared<op::DynReshape>(data_arg_label, shape_arg_label);
auto dyn_reshape_callback = [data_arg_label, shape_arg_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto data_arg = pattern_map[data_arg_label];
auto shape_arg = static_pointer_cast<op::Constant>(pattern_map[shape_arg_label]);
auto dyn_reshape_node = static_pointer_cast<op::DynReshape>(m.get_match_root());
// TODO(amprocte): Can't handle the case where data rank is dynamic even if we know the
// output shape, because static Reshape requries an axis permutation (here an identity) to
// be given. See if we can come up with a workaround.
if (data_arg->get_output_partial_shape(0).rank().is_dynamic())
{
return false;
}
if (dyn_reshape_node->get_output_partial_shape(0).is_dynamic())
{
return false;
}
auto& result_shape = dyn_reshape_node->get_output_shape(0);
AxisVector perm(size_t(data_arg->get_output_partial_shape(0).rank()));
std::iota(perm.begin(), perm.end(), 0);
auto replacement = std::make_shared<op::Reshape>(data_arg, perm, result_shape);
replace_node(dyn_reshape_node, replacement);
return true;
};
auto dyn_reshape_matcher =
make_shared<pattern::Matcher>(dyn_reshape, "DynElimination.DynReshape");
add_matcher(dyn_reshape_matcher, dyn_reshape_callback, all_pass_property_off);
}
template <typename T>
std::shared_ptr<op::Constant>
make_range_replacement_integral(const element::Type& et,
......
......@@ -31,6 +31,7 @@ namespace ngraph
private:
void construct_transpose();
void construct_broadcast();
void construct_dyn_slice();
void construct_dyn_reshape();
void construct_range();
};
......
......@@ -103,17 +103,23 @@ shared_ptr<runtime::Executable>
#endif
shared_ptr<runtime::Executable> rc;
auto it = m_exec_map.find(func);
if (it != m_exec_map.end())
// we will protect the access to map (m_exec_map) across multiple threads by creating a lock_gaurd
// m_exec_map_mutex will be released once the object `guard` goes out of scope
{
rc = it->second;
std::lock_guard<std::mutex> guard(m_exec_map_mutex);
auto it = m_exec_map.find(func);
if (it != m_exec_map.end())
{
rc = it->second;
return rc;
}
}
else
rc = make_shared<CPU_Executable>(func, pass_config, performance_counters_enabled);
{
rc = make_shared<CPU_Executable>(func, pass_config, performance_counters_enabled);
std::lock_guard<std::mutex> guard(m_exec_map_mutex);
m_exec_map.insert({func, rc});
return rc;
}
return rc;
}
runtime::cpu::CPU_Executable::CPU_Executable(shared_ptr<Function> func,
......@@ -156,6 +162,7 @@ bool runtime::cpu::CPU_Executable::call(const vector<shared_ptr<runtime::Tensor>
void runtime::cpu::CPU_Backend::remove_compiled_function(shared_ptr<Executable> exec)
{
std::lock_guard<std::mutex> guard(m_exec_map_mutex);
for (auto it = m_exec_map.begin(); it != m_exec_map.end(); ++it)
{
if (it->second == exec)
......
......@@ -18,6 +18,7 @@
#include <map>
#include <memory>
#include <mutex>
#include "cpu_backend_visibility.h"
#include "ngraph/pass/pass_config.hpp"
......@@ -63,6 +64,9 @@ namespace ngraph
bool is_supported_property(const Property prop) const override;
private:
// this mutex will be used to protect the addition and deletion
// of function to m_exec_map across multiple threads
std::mutex m_exec_map_mutex;
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<Executable>>
m_exec_map;
};
......
......@@ -87,11 +87,13 @@
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
......@@ -2070,6 +2072,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::GenerateMask:
case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::GRUCell:
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::LeakyRelu:
case OP_TYPEID::LSTMCell:
......@@ -2077,6 +2080,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::Normalize:
case OP_TYPEID::PRelu:
case OP_TYPEID::Passthrough:
case OP_TYPEID::RNNCell:
case OP_TYPEID::QuantizedAvgPool:
case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedConvolutionBias:
......@@ -2195,11 +2199,13 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::Gemm:
case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::GRUCell:
case OP_TYPEID::LeakyRelu:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MVN:
case OP_TYPEID::Normalize:
case OP_TYPEID::PRelu:
case OP_TYPEID::RNNCell:
case OP_TYPEID::ScaleShift:
case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::SpaceToDepth:
......
......@@ -259,6 +259,12 @@ backwards_softmax_underflow
backwards_softmax_3d
batch_mat_mul_forward
dot_matrix_2x0_0x2
rnn_cell_no_bias
rnn_cell_bias_clip
rnn_cell_activation_function
gru_cell_bias_clip
gru_cell_linear_before_reset
gru_cell_activation_function
# dgkutnic ww24.5: these tests are to be triaged by the PlaidML team
# ww25.2: re-scrubbed this list of tests after fixing check_inputs
......
......@@ -79,12 +79,14 @@
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
......@@ -1230,13 +1232,6 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::GRN>(args[0], bias);
break;
}
case OP_TYPEID::HardSigmoid:
{
auto alpha = node_js.at("alpha").get<float>();
auto beta = node_js.at("beta").get<float>();
node = make_shared<op::HardSigmoid>(args[0], alpha, beta);
break;
}
case OP_TYPEID::GroupConvolution:
{
auto window_movement_strides =
......@@ -1283,6 +1278,35 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
output_shape);
break;
}
case OP_TYPEID::GRUCell:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
auto linear_before_reset = node_js.at("linear_before_reset").get<bool>();
node = make_shared<op::GRUCell>(args[0],
args[1],
args[2],
args[3],
hidden_size,
args[4],
activations,
activation_alpha,
activation_beta,
clip,
linear_before_reset);
break;
}
case OP_TYPEID::HardSigmoid:
{
auto alpha = node_js.at("alpha").get<float>();
auto beta = node_js.at("beta").get<float>();
node = make_shared<op::HardSigmoid>(args[0], alpha, beta);
break;
}
case OP_TYPEID::LeakyRelu:
{
node = make_shared<op::LeakyRelu>(args[0], args[1]);
......@@ -1664,6 +1688,25 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
break;
}
case OP_TYPEID::RNNCell:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
node = make_shared<op::RNNCell>(args[0],
args[1],
args[2],
args[3],
hidden_size,
args[4],
activations,
activation_alpha,
activation_beta,
clip);
break;
}
case OP_TYPEID::ScalarConstantLike:
{
double value = node_js.at("value").get<double>();
......@@ -2340,11 +2383,15 @@ json JSONSerializer::serialize_node(const Node& n)
node["bias"] = tmp->get_bias();
break;
}
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::GRUCell:
{
auto tmp = dynamic_cast<const op::HardSigmoid*>(&n);
node["alpha"] = tmp->get_alpha();
node["beta"] = tmp->get_beta();
auto tmp = dynamic_cast<const op::GRUCell*>(&n);
node["hidden_size"] = tmp->get_hidden_size();
node["clip"] = tmp->get_clip();
node["activations"] = tmp->get_activations();
node["activation_alpha"] = tmp->get_activation_alpha();
node["activation_beta"] = tmp->get_activation_beta();
node["linear_before_reset"] = tmp->get_linear_before_reset();
break;
}
case OP_TYPEID::GroupConvolution:
......@@ -2372,6 +2419,13 @@ json JSONSerializer::serialize_node(const Node& n)
node["output_shape"] = tmp->get_output_shape();
break;
}
case OP_TYPEID::HardSigmoid:
{
auto tmp = dynamic_cast<const op::HardSigmoid*>(&n);
node["alpha"] = tmp->get_alpha();
node["beta"] = tmp->get_beta();
break;
}
case OP_TYPEID::LeakyRelu: { break;
}
case OP_TYPEID::Less:
......@@ -2662,6 +2716,16 @@ json JSONSerializer::serialize_node(const Node& n)
node["sequence_axis"] = tmp->get_sequence_axis();
break;
}
case OP_TYPEID::RNNCell:
{
auto tmp = dynamic_cast<const op::RNNCell*>(&n);
node["hidden_size"] = tmp->get_hidden_size();
node["clip"] = tmp->get_clip();
node["activations"] = tmp->get_activations();
node["activation_alpha"] = tmp->get_activation_alpha();
node["activation_beta"] = tmp->get_activation_beta();
break;
}
case OP_TYPEID::ScalarConstantLike:
{
auto tmp = dynamic_cast<const op::ScalarConstantLikeBase*>(&n);
......
This diff is collapsed.
......@@ -132,6 +132,30 @@ TEST(dyn_elimination, slice)
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 4, 2, 2, 1, 2, 2}));
}
TEST(dyn_elimination, reshape)
{
auto input_arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto shape_arg = make_shared<op::Constant>(element::i64, Shape{3}, vector<int64_t>{0, 6, -1});
auto r = make_shared<op::DynReshape>(input_arg, shape_arg, true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{2, 6, 32}));
auto f = make_shared<Function>(r, ParameterVector{input_arg});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 1);
ASSERT_EQ(f->get_results().at(0)->get_element_type(), element::f32);
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 6, 32}));
}
TEST(dyn_elimination, range)
{
auto constant_start = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{0});
......
......@@ -365,3 +365,94 @@ NGRAPH_TEST(dynamic_${BACKEND_NAME}, range)
ASSERT_EQ(results, test.expected_result);
}
}
NGRAPH_TEST(dynamic_${BACKEND_NAME}, reshape)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto build_graph = [&backend](bool zero_flag) {
// Create a graph for f(x,shape) = DynReshape(x,shape,zero_flag=zero_flag).
auto x = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto dyn_reshape = make_shared<op::DynReshape>(x, shape, zero_flag);
EXPECT_TRUE(dyn_reshape->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
auto f = make_shared<Function>(NodeVector{dyn_reshape}, ParameterVector{x, shape});
auto ex = backend->compile(f);
return ex;
};
auto t_r = backend->create_dynamic_tensor(element::i32, PartialShape::dynamic());
auto ex_flag_off = build_graph(false);
auto ex_flag_on = build_graph(true);
std::vector<std::tuple<bool, Shape, std::vector<int32_t>, std::vector<int64_t>, Shape>> tests;
tests.emplace_back(make_tuple(
false, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6}, vector<int64_t>{6}, Shape{6}));
tests.emplace_back(make_tuple(
true, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6}, vector<int64_t>{6}, Shape{6}));
tests.emplace_back(make_tuple(
false, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6}, vector<int64_t>{-1}, Shape{6}));
tests.emplace_back(make_tuple(false,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{2, -1},
Shape{2, 3}));
tests.emplace_back(make_tuple(false,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{3, -1},
Shape{3, 2}));
tests.emplace_back(make_tuple(false,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{3, 2, -1},
Shape{3, 2, 1}));
tests.emplace_back(make_tuple(true,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{3, 2, -1},
Shape{3, 2, 1}));
tests.emplace_back(make_tuple(true,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{0, 0, -1},
Shape{2, 3, 1}));
tests.emplace_back(make_tuple(true,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{2, 0, -1},
Shape{2, 3, 1}));
tests.emplace_back(make_tuple(
true, Shape{0, 3, 4}, vector<int32_t>{}, vector<int64_t>{3, -1, 2}, Shape{3, 0, 2}));
for (auto& test : tests)
{
bool zero_flag = get<0>(test);
const Shape& in_shape = get<1>(test);
const std::vector<int32_t>& data = get<2>(test);
const std::vector<int64_t>& dims = get<3>(test);
const Shape& out_shape = get<4>(test);
auto t_x = backend->create_tensor(element::i32, in_shape);
auto t_shape = backend->create_tensor(element::i64, Shape{dims.size()});
copy_data(t_x, data);
copy_data(t_shape, dims);
auto ex = zero_flag ? ex_flag_on : ex_flag_off;
ex->call_with_validate({t_r}, {t_x, t_shape});
ASSERT_EQ(t_r->get_element_type(), element::i32);
ASSERT_EQ(t_r->get_shape(), out_shape);
auto results = read_vector<int32_t>(t_r);
ASSERT_EQ(results, data);
}
}
......@@ -15763,3 +15763,164 @@ INSTANTIATE_TEST_CASE_P(type_prop,
::testing::Values(RangeParams{0, 1, 0.25, PartialShape{4}},
RangeParams{-1, 1, 0.25, PartialShape{8}},
RangeParams{-1, 0.875, 0.25, PartialShape{8}}));
TEST(type_prop, rnn_cell)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
EXPECT_EQ(rnn_cell->output(0).get_element_type(), element::f32);
EXPECT_EQ(rnn_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
}
TEST(type_prop, rnn_cell_invalid_input)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
// Invalid W tensor shape.
auto W = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size, input_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor W must have shape"));
}
// Invalid R tensor shape.
W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor R must have shape"));
}
// Invalid H_t tensor shape.
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
}
// Invalid B tensor shape.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, W, R, H_t, hidden_size, B);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor B must have shape"));
}
}
TEST(type_prop, gru_cell)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 3;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
const auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto gru_cell = make_shared<op::GRUCell>(X, W, R, H_t, hidden_size);
EXPECT_EQ(gru_cell->output(0).get_element_type(), element::f32);
EXPECT_EQ(gru_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
}
TEST(type_prop, gru_cell_invalid_input)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 3;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
// Invalid W tensor shape.
auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, W, R, H_t, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor W must have shape"));
}
// Invalid R tensor shape.
W = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, W, R, H_t, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor R must have shape"));
}
// Invalid H_t tensor shape.
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, W, R, H_t, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
}
// Invalid B tensor shape.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, W, R, H_t, hidden_size, B);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor B must have shape"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment