Unverified Commit 69a2a323 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into ldurka/onnxruntime_ci

parents bcf30a63 727c55d3
Contributor Guidelines Contributor Guidelines
====================== ======================
https://ngraph.nervanasys.com/docs/latest/project/code-contributor-README.html The latest version of this file can be found at:
https://ngraph.nervanasys.com/docs/latest/project/contribution-guide.html
License License
......
...@@ -142,6 +142,8 @@ set (SRC ...@@ -142,6 +142,8 @@ set (SRC
op/experimental/dyn_broadcast.hpp op/experimental/dyn_broadcast.hpp
op/experimental/dyn_pad.cpp op/experimental/dyn_pad.cpp
op/experimental/dyn_pad.hpp op/experimental/dyn_pad.hpp
op/experimental/dyn_replace_slice.cpp
op/experimental/dyn_replace_slice.hpp
op/experimental/dyn_reshape.cpp op/experimental/dyn_reshape.cpp
op/experimental/dyn_reshape.hpp op/experimental/dyn_reshape.hpp
op/experimental/dyn_slice.cpp op/experimental/dyn_slice.cpp
...@@ -318,6 +320,8 @@ set (SRC ...@@ -318,6 +320,8 @@ set (SRC
op/fused/group_conv.cpp op/fused/group_conv.cpp
op/fused/group_conv_transpose.hpp op/fused/group_conv_transpose.hpp
op/fused/group_conv_transpose.cpp op/fused/group_conv_transpose.cpp
op/fused/gru_cell.cpp
op/fused/gru_cell.hpp
op/fused/leaky_relu.cpp op/fused/leaky_relu.cpp
op/fused/leaky_relu.hpp op/fused/leaky_relu.hpp
op/fused/lstm_cell.cpp op/fused/lstm_cell.cpp
...@@ -328,6 +332,8 @@ set (SRC ...@@ -328,6 +332,8 @@ set (SRC
op/fused/normalize.hpp op/fused/normalize.hpp
op/fused/prelu.cpp op/fused/prelu.cpp
op/fused/prelu.hpp op/fused/prelu.hpp
op/fused/rnn_cell.cpp
op/fused/rnn_cell.hpp
op/fused/scale_shift.cpp op/fused/scale_shift.cpp
op/fused/scale_shift.hpp op/fused/scale_shift.hpp
op/fused/shuffle_channels.cpp op/fused/shuffle_channels.cpp
......
...@@ -46,6 +46,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type, ...@@ -46,6 +46,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
void descriptor::Tensor::set_tensor_type(const element::Type& element_type, void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const PartialShape& pshape) const PartialShape& pshape)
{ {
NGRAPH_CHECK(pshape.all_non_negative(),
"set_tensor_type called on a PartialShape containing negative dimensions: ",
pshape);
if (pshape.is_static()) if (pshape.is_static())
{ {
m_shape = pshape.to_shape(); m_shape = pshape.to_shape();
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
using namespace ngraph; using namespace ngraph;
Dimension::Dimension(size_t dimension) Dimension::Dimension(int64_t dimension)
: m_dimension(dimension) : m_dimension(dimension)
{ {
if (dimension == s_dynamic_val) if (dimension == s_dynamic_val)
...@@ -40,7 +40,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension) ...@@ -40,7 +40,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
{ {
if (dimension.is_static()) if (dimension.is_static())
{ {
return (str << size_t(dimension)); return (str << int64_t(dimension));
} }
else else
{ {
...@@ -50,36 +50,36 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension) ...@@ -50,36 +50,36 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
Dimension Dimension::operator+(const Dimension& dim) const Dimension Dimension::operator+(const Dimension& dim) const
{ {
return (is_static() && dim.is_static() ? m_dimension + size_t(dim) : Dimension::dynamic()); return (is_static() && dim.is_static() ? m_dimension + int64_t(dim) : Dimension::dynamic());
} }
Dimension Dimension::operator-(const Dimension& dim) const Dimension Dimension::operator-(const Dimension& dim) const
{ {
return (is_static() && dim.is_static() ? m_dimension - size_t(dim) : Dimension::dynamic()); return (is_static() && dim.is_static() ? m_dimension - int64_t(dim) : Dimension::dynamic());
} }
Dimension Dimension::operator*(const Dimension& dim) const Dimension Dimension::operator*(const Dimension& dim) const
{ {
return ((is_static() && dim.is_static()) return ((is_static() && dim.is_static())
? m_dimension * size_t(dim) ? m_dimension * int64_t(dim)
: (is_static() && m_dimension == 0) : (is_static() && m_dimension == 0)
? 0 ? 0
: (dim.is_static() && size_t(dim) == 0) ? 0 : Dimension::dynamic()); : (dim.is_static() && int64_t(dim) == 0) ? 0 : Dimension::dynamic());
} }
bool Dimension::compatible(const Dimension& d) const bool Dimension::compatible(const Dimension& d) const
{ {
return (is_dynamic() || d.is_dynamic() || m_dimension == size_t(d)); return (is_dynamic() || d.is_dynamic() || m_dimension == int64_t(d));
} }
bool Dimension::relaxes(const Dimension& d) const bool Dimension::relaxes(const Dimension& d) const
{ {
return (is_dynamic() || (d.is_static() && size_t(*this) == size_t(d))); return (is_dynamic() || (d.is_static() && int64_t(*this) == int64_t(d)));
} }
bool Dimension::refines(const Dimension& d) const bool Dimension::refines(const Dimension& d) const
{ {
return (d.is_dynamic() || (is_static() && size_t(d) == size_t(*this))); return (d.is_dynamic() || (is_static() && int64_t(d) == int64_t(*this)));
} }
bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2) bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
...@@ -94,7 +94,7 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2) ...@@ -94,7 +94,7 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
dst = d1; dst = d1;
return true; return true;
} }
else if (size_t(d1) != size_t(d2)) else if (int64_t(d1) != int64_t(d2))
{ {
return false; return false;
} }
...@@ -115,16 +115,16 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens ...@@ -115,16 +115,16 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
else if (d1.is_dynamic() || d2.is_dynamic()) else if (d1.is_dynamic() || d2.is_dynamic())
{ {
// One static. Set dst to static size if >1 // One static. Set dst to static size if >1
auto ds = d1.is_dynamic() ? size_t(d2) : size_t(d1); auto ds = d1.is_dynamic() ? int64_t(d2) : int64_t(d1);
dst = (ds > 1) ? ds : Dimension::dynamic(); dst = (ds > 1) ? ds : Dimension::dynamic();
return true; return true;
} }
else else
{ {
// Static sizes. Both match or one of them is 1. // Static sizes. Both match or one of them is 1.
if (size_t(d1) == size_t(d2) || size_t(d1) == 1 || size_t(d2) == 1) if (int64_t(d1) == int64_t(d2) || int64_t(d1) == 1 || int64_t(d2) == 1)
{ {
dst = std::max(size_t(d1), size_t(d2)); dst = std::max(int64_t(d1), int64_t(d2));
return true; return true;
} }
else else
......
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime), /// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object. /// in a shape or shape-like object.
/// ///
/// Static dimensions may be implicitly converted from size_t. A dynamic dimension is /// Static dimensions may be implicitly converted from int64_t. A dynamic dimension is
/// constructed with Dimension() or Dimension::dynamic(). /// constructed with Dimension() or Dimension::dynamic().
/// ///
/// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE. /// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// \param dimension Value of the dimension. Must not be equal to /// \param dimension Value of the dimension. Must not be equal to
/// Dimension::s_dynamic_val. /// Dimension::s_dynamic_val.
/// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val. /// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val.
Dimension(size_t dimension); Dimension(int64_t dimension);
/// \brief Construct a dynamic dimension. /// \brief Construct a dynamic dimension.
Dimension() { m_dimension = s_dynamic_val; } Dimension() { m_dimension = s_dynamic_val; }
...@@ -46,25 +46,30 @@ namespace ngraph ...@@ -46,25 +46,30 @@ namespace ngraph
/// \brief Check whether this dimension is dynamic. /// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`. /// \return `false` if the dimension is static, else `true`.
bool is_dynamic() const { return !is_static(); } bool is_dynamic() const { return !is_static(); }
/// \brief Convert this dimension to `size_t`. This dimension must be static. /// \brief Convert this dimension to `int64_t`. This dimension must be static.
/// \throws std::invalid_argument If this dimension is dynamic. /// \throws std::invalid_argument If this dimension is dynamic.
explicit operator size_t() const explicit operator int64_t() const
{ {
if (is_dynamic()) if (is_dynamic())
{ {
throw std::invalid_argument("Cannot convert dynamic dimension to size_t"); throw std::invalid_argument("Cannot convert dynamic dimension to int64_t");
} }
return m_dimension; return m_dimension;
} }
/// \brief Convert this dimension to `ptrdiff_t`. This dimension must be static. /// \brief Convert this dimension to `size_t`. This dimension must be static and
/// \throws std::invalid_argument If this dimension is dynamic. /// non-negative.
explicit operator ptrdiff_t() const /// \throws std::invalid_argument If this dimension is dynamic or negative.
explicit operator size_t() const
{ {
if (is_dynamic()) if (is_dynamic())
{ {
throw std::invalid_argument("Cannot convert dynamic dimension to ptrdiff_t"); throw std::invalid_argument("Cannot convert dynamic dimension to size_t");
}
if (m_dimension < 0)
{
throw std::invalid_argument("Cannot convert negative dimension to size_t");
} }
return static_cast<ptrdiff_t>(m_dimension); return m_dimension;
} }
/// \brief Check whether this dimension represents the same scheme as the argument (both /// \brief Check whether this dimension represents the same scheme as the argument (both
...@@ -75,7 +80,7 @@ namespace ngraph ...@@ -75,7 +80,7 @@ namespace ngraph
bool same_scheme(const Dimension& dim) const bool same_scheme(const Dimension& dim) const
{ {
return (is_dynamic() && dim.is_dynamic()) || return (is_dynamic() && dim.is_dynamic()) ||
(is_static() && dim.is_static() && m_dimension == size_t(dim)); (is_static() && dim.is_static() && m_dimension == int64_t(dim));
} }
/// \brief Try to merge two Dimension objects together. /// \brief Try to merge two Dimension objects together.
...@@ -128,25 +133,25 @@ namespace ngraph ...@@ -128,25 +133,25 @@ namespace ngraph
/// \return A dynamic dimension. /// \return A dynamic dimension.
static Dimension dynamic() { return Dimension(); } static Dimension dynamic() { return Dimension(); }
/// \brief Constant for the value used internally to represent a dynamic dimension. /// \brief Constant for the value used internally to represent a dynamic dimension.
static const size_t s_dynamic_val{(std::numeric_limits<size_t>::max())}; static const int64_t s_dynamic_val{(std::numeric_limits<int64_t>::max())};
/// \brief Addition operator for Dimension. /// \brief Addition operator for Dimension.
/// \param dim Right operand for addition. /// \param dim Right operand for addition.
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static /// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// dimension with value `size_t(*this)+size_t(dim)`. /// dimension with value `int64_t(*this)+in64_t(dim)`.
Dimension operator+(const Dimension& dim) const; Dimension operator+(const Dimension& dim) const;
/// \brief Subtraction operator for Dimension. /// \brief Subtraction operator for Dimension.
/// \param dim Right operand for subtraction. /// \param dim Right operand for subtraction.
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static /// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// dimension with value `size_t(*this)-size_t(dim)`. /// dimension with value `int64_t(*this)-int64_t(dim)`.
Dimension operator-(const Dimension& dim) const; Dimension operator-(const Dimension& dim) const;
/// \brief Multiplication operator for Dimension. /// \brief Multiplication operator for Dimension.
/// \param dim Right operand for multiplicaiton. /// \param dim Right operand for multiplicaiton.
/// \return 0 if either of `*this` or `dim` is static and 0; else, Dimension::dynamic() if /// \return 0 if either of `*this` or `dim` is static and 0; else, Dimension::dynamic() if
/// either of `*this` or `dim` is dynamic; else, a static dimension with value /// either of `*this` or `dim` is dynamic; else, a static dimension with value
/// `size_t(*this)*size_t(dim)`. /// `int64_t(*this)*int64_t(dim)`.
Dimension operator*(const Dimension& dim) const; Dimension operator*(const Dimension& dim) const;
/// \brief Add-into operator for Dimension. /// \brief Add-into operator for Dimension.
...@@ -160,7 +165,7 @@ namespace ngraph ...@@ -160,7 +165,7 @@ namespace ngraph
private: private:
// The actual numerical value of the dimension. s_dynamic_val is a special case, // The actual numerical value of the dimension. s_dynamic_val is a special case,
// representing a dynamic dimension. // representing a dynamic dimension.
size_t m_dimension; int64_t m_dimension;
}; };
/// \brief Insert a human-readable representation of a dimension into an output stream. /// \brief Insert a human-readable representation of a dimension into an output stream.
...@@ -168,6 +173,6 @@ namespace ngraph ...@@ -168,6 +173,6 @@ namespace ngraph
/// \param dimension The dimension to be inserted into `str`. /// \param dimension The dimension to be inserted into `str`.
/// \return A reference to `str` after insertion. /// \return A reference to `str` after insertion.
/// ///
/// Inserts the string `?` if `dimension` is dynamic; else inserts `size_t(dimension)`. /// Inserts the string `?` if `dimension` is dynamic; else inserts `int64_t(dimension)`.
std::ostream& operator<<(std::ostream& str, const Dimension& dimension); std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
} }
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
......
...@@ -89,6 +89,7 @@ ...@@ -89,6 +89,7 @@
#include "ngraph/op/experimental/batch_mat_mul.hpp" #include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp" #include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp" #include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp" #include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp" #include "ngraph/op/experimental/range.hpp"
...@@ -105,12 +106,14 @@ ...@@ -105,12 +106,14 @@
#include "ngraph/op/fused/grn.hpp" #include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp" #include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp" #include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp" #include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp" #include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp" #include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp" #include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp" #include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp" #include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp" #include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp" #include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/space_to_depth.hpp" #include "ngraph/op/fused/space_to_depth.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
op::DynReplaceSlice::DynReplaceSlice(const shared_ptr<Node>& arg,
const shared_ptr<Node>& replacement,
const shared_ptr<Node>& lower_bounds,
const shared_ptr<Node>& upper_bounds,
const shared_ptr<Node>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis,
const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask)
: Op("DynReplaceSlice",
check_single_output_args({arg, replacement, lower_bounds, upper_bounds, strides}))
, m_lower_bounds_mask(lower_bounds_mask)
, m_upper_bounds_mask(upper_bounds_mask)
, m_new_axis(new_axis)
, m_shrink_axis(shrink_axis)
, m_ellipsis_mask(ellipsis_mask)
{
constructor_validate_and_infer_types();
}
void op::DynReplaceSlice::validate_and_infer_types()
{
auto arg_et = get_input_element_type(0);
auto replacement_et = get_input_element_type(1);
auto lower_bounds_et = get_input_element_type(2);
auto upper_bounds_et = get_input_element_type(3);
auto strides_et = get_input_element_type(4);
element::Type result_et;
// check data types
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_et, arg_et, replacement_et),
"Argument element type is not compatible with replacement element type");
NODE_VALIDATION_CHECK(this,
lower_bounds_et.compatible(element::Type_t::i64),
"Lower bounds must have element type i64.");
NODE_VALIDATION_CHECK(this,
upper_bounds_et.compatible(element::Type_t::i64),
"Upper bounds must have element type i64.");
NODE_VALIDATION_CHECK(
this, strides_et.compatible(element::Type_t::i64), "Strides must have element type i64");
// check shapes
auto arg_shape = get_input_partial_shape(0);
auto replacement_shape = get_input_partial_shape(1);
auto lower_bounds_shape = get_input_partial_shape(2);
auto upper_bounds_shape = get_input_partial_shape(3);
auto strides_shape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
lower_bounds_shape.rank().compatible(1),
"Lower bounds shape must have rank 1, got ",
lower_bounds_shape.rank(),
".");
NODE_VALIDATION_CHECK(this,
upper_bounds_shape.rank().compatible(1),
"Upper bounds shape must have rank 1, got ",
upper_bounds_shape.rank(),
".");
NODE_VALIDATION_CHECK(this,
strides_shape.rank().compatible(1),
"Strides shape must have rank 1, got ",
strides_shape.rank(),
".");
PartialShape attrs_shape{PartialShape::dynamic()};
NODE_VALIDATION_CHECK(this,
(lower_bounds_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, lower_bounds_shape)) &&
(upper_bounds_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, upper_bounds_shape)) &&
(strides_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, strides_shape)),
"Shapes for lower bounds, upper bounds, and strides do not match");
set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3);
set_input_is_relevant_to_shape(4);
auto lower_bounds = dynamic_pointer_cast<op::Constant>(get_argument(2));
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(3));
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(4));
// TODO(amprocte): We can get a bit more information here about the ranks of arg and
// replacement by inspecting the attributes.
auto slice_shape = PartialShape::dynamic();
if (lower_bounds && upper_bounds && strides)
{
slice_shape = infer_slice_shape(this,
get_input_partial_shape(0),
lower_bounds->get_vector<int64_t>(),
upper_bounds->get_vector<int64_t>(),
strides->get_vector<int64_t>(),
m_lower_bounds_mask,
m_upper_bounds_mask,
m_new_axis,
m_shrink_axis,
m_ellipsis_mask);
}
NODE_VALIDATION_CHECK(this,
slice_shape.compatible(replacement_shape),
"Shape of the replacement is not compatible with the shape of the "
"slice (shape of slice: ",
slice_shape,
")");
set_output_type(0, result_et, arg_shape);
}
shared_ptr<Node> op::DynReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<DynReplaceSlice>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
m_lower_bounds_mask,
m_upper_bounds_mask,
m_new_axis,
m_shrink_axis,
m_ellipsis_mask);
}
void op::DynReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
throw ngraph_error("generate_adjoints not implemented for DynReplaceSlice");
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Takes a slice of an input tensor, i.e., the sub-tensor that resides within a bounding box, optionally with stride.
class DynReplaceSlice : public Op
{
public:
/// \brief Constructs a dynamic tensor replace-slice operation.
///
/// \param arg The tensor in which to replace the slice.
/// \param replacement Data to copy to the slice for replacement.
/// \param lower_bounds The axiswise lower bounds of the slice (inclusive).
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// \param strides The slicing strides; for example, strides of `{n,m}` means to take
/// every nth row and every mth column of the input matrix.
/// \param lower_bounds_mask Ignores lower_bounds for axis with the mask set
/// \param upper_bounds_mask Ignores upper_bounds for axis with the mask set
/// \param new_axis Add dimension one axis at the set positions
/// \param shrink_axis Delete dimensions at the set positions
/// \param ellipsis_mask Inserts missing dimensions on the set position
DynReplaceSlice(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& replacement,
const std::shared_ptr<Node>& lower_bounds,
const std::shared_ptr<Node>& upper_bounds,
const std::shared_ptr<Node>& strides,
const AxisSet& lower_bounds_mask = AxisSet{},
const AxisSet& upper_bounds_mask = AxisSet{},
const AxisSet& new_axis = AxisSet{},
const AxisSet& shrink_axis = AxisSet{},
const AxisSet& ellipsis_mask = AxisSet{});
const AxisSet& get_lower_bounds_mask() const { return m_lower_bounds_mask; }
const AxisSet& get_upper_bounds_mask() const { return m_upper_bounds_mask; }
const AxisSet& get_new_axis() const { return m_new_axis; }
const AxisSet& get_shrink_axis() const { return m_shrink_axis; }
const AxisSet& get_ellipsis_mask() const { return m_ellipsis_mask; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void validate_and_infer_types() override;
private:
/// Helper method to compute output shape
Shape compute_output_shape() const;
AxisSet m_lower_bounds_mask;
AxisSet m_upper_bounds_mask;
AxisSet m_new_axis;
AxisSet m_shrink_axis;
AxisSet m_ellipsis_mask;
};
}
}
...@@ -99,7 +99,8 @@ void op::DynReshape::validate_and_infer_types() ...@@ -99,7 +99,8 @@ void op::DynReshape::validate_and_infer_types()
if (out_shape_val[i] == 0 && m_zero_flag) if (out_shape_val[i] == 0 && m_zero_flag)
{ {
// Copy input_shape[i] for zero values // Copy input_shape[i] for zero values
NGRAPH_CHECK(i < input_shape.size()); NODE_VALIDATION_CHECK(
this, i < input_shape.size(), "'0' dimension is out of range");
partial_shape[i] = Dimension(input_shape[i]); partial_shape[i] = Dimension(input_shape[i]);
output_elements *= input_shape[i]; output_elements *= input_shape[i];
} }
...@@ -119,12 +120,21 @@ void op::DynReshape::validate_and_infer_types() ...@@ -119,12 +120,21 @@ void op::DynReshape::validate_and_infer_types()
// input elements // input elements
if (output_elements == 0) if (output_elements == 0)
{ {
NGRAPH_CHECK(input_elements == 0); // TODO(amprocte): Decide if this is desired behavior here. (NumPy seems
// to fail.)
NODE_VALIDATION_CHECK(this,
input_elements == 0,
"Cannot infer '-1' dimension with zero-size output "
"dimension unless at least one input dimension is "
"also zero-size");
partial_shape[negative_dim] = Dimension(0); partial_shape[negative_dim] = Dimension(0);
} }
else else
{ {
NGRAPH_CHECK(input_elements % output_elements == 0); NODE_VALIDATION_CHECK(
this,
input_elements % output_elements == 0,
"Non-'-1' output dimensions do not evenly divide the input dimensions");
partial_shape[negative_dim] = Dimension(input_elements / output_elements); partial_shape[negative_dim] = Dimension(input_elements / output_elements);
} }
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "ngraph/op/experimental/dyn_slice.hpp" #include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
#include <memory> #include <memory>
...@@ -42,142 +43,6 @@ op::DynSlice::DynSlice(const shared_ptr<Node>& arg, ...@@ -42,142 +43,6 @@ op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
Shape op::DynSlice::compute_output_shape() const
{
auto input_shape = get_input_partial_shape(0).to_shape();
auto lower_bounds = dynamic_pointer_cast<op::Constant>(get_argument(1));
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(2));
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(3));
if (lower_bounds && upper_bounds && strides)
{
auto lb = lower_bounds->get_vector<int64_t>();
auto ub = upper_bounds->get_vector<int64_t>();
auto str = strides->get_vector<int64_t>();
int max_dims = input_shape.size() + m_new_axis.size();
if (lb.size() && ub.size())
{
NODE_VALIDATION_CHECK(
this,
lb.size() == ub.size(),
"Lower bounds and Upper bounds needs to have same number of values");
}
if (lb.size() && str.size())
{
NODE_VALIDATION_CHECK(this,
lb.size() == str.size(),
"Lower bounds and strides needs to have same number of values");
}
if (ub.size() && str.size())
{
NODE_VALIDATION_CHECK(this,
ub.size() == str.size(),
"Upper bounds and strides needs to have same number of values");
}
int bounds_size =
lb.size() ? lb.size() : (ub.size() ? ub.size() : (str.size() ? str.size() : 0));
NODE_VALIDATION_CHECK(
this, m_ellipsis_mask.size() <= 1, "Ellipsis mask cannot specify more than one axis");
int ellipsis_pos1 = m_ellipsis_mask.size() ? *m_ellipsis_mask.begin() : max_dims;
int ellipsis_pos2 = max_dims;
bounds_size -= ellipsis_pos1;
if (bounds_size > 0 && (max_dims - bounds_size) > ellipsis_pos1)
{
ellipsis_pos2 = max_dims - bounds_size;
}
std::vector<int> begin_dms(max_dims, 0);
std::vector<int> end_dms(max_dims, -1);
std::vector<int> stride_dms(max_dims, 1);
int i, j, k, bj, ej, sj;
Shape out_dims;
for (i = 0, j = 0, k = 0, bj = 0, ej = 0, sj = 0; i < max_dims; i++)
{
if (i >= ellipsis_pos1 && i < ellipsis_pos2)
{
if (m_new_axis.find(i) == m_new_axis.end())
{
end_dms[i] = end_dms[i] >= 0 ? end_dms[i] : input_shape[j++] + end_dms[i];
}
else
{
end_dms[i] = begin_dms[i];
}
out_dims.push_back(
static_cast<int>(ceil(static_cast<float>(abs(end_dms[i] - begin_dms[i]) + 1) /
static_cast<float>(abs(stride_dms[i])))));
k = ellipsis_pos1;
continue;
}
stride_dms[i] = (str.size() > sj && str[sj] != 0) ? str[sj++] : 1;
// Use lower_bounds if mask is not set
if (m_lower_bounds_mask.find(j) == m_lower_bounds_mask.end())
{
begin_dms[i] = lb.size() > bj ? lb[bj] : (stride_dms[i] > 0 ? 0 : -1);
}
else
{
begin_dms[i] = stride_dms[i] > 0 ? 0 : -1;
}
bj++;
begin_dms[i] = begin_dms[i] >= 0 ? begin_dms[i] : input_shape[j] + begin_dms[i];
// Clipping 'begin'
begin_dms[i] =
(begin_dms[i] < 0) ? 0 : (begin_dms[i] >= input_shape[j] ? input_shape[j] - 1
: begin_dms[i]);
// Use upper_bounds if mask is not set
if (m_upper_bounds_mask.find(j) == m_upper_bounds_mask.end())
{
int end_dms_tmp =
ub.size() > ej ? (stride_dms[i] > 0 ? ub[ej] - 1 : ub[ej] + 1) : end_dms[i];
end_dms[i] = ub.size() > ej ? end_dms_tmp : (stride_dms[i] > 0 ? -1 : 0);
}
else
{
end_dms[i] = stride_dms[i] > 0 ? -1 : 0;
}
ej++;
end_dms[i] = end_dms[i] >= 0 ? end_dms[i] : input_shape[j] + end_dms[i];
// Clipping 'end'
end_dms[i] = (end_dms[i] < 0) ? 0 : (end_dms[i] >= input_shape[j] ? input_shape[j] - 1
: end_dms[i]);
if (m_new_axis.find(i) == m_new_axis.end())
{
j++;
}
else
{
end_dms[i] = 0;
}
if (m_shrink_axis.find(k) != m_shrink_axis.end())
{
end_dms[i] = begin_dms[i];
}
else
{
out_dims.push_back(
static_cast<int>(ceil(static_cast<float>(abs(end_dms[i] - begin_dms[i]) + 1) /
static_cast<float>(abs(stride_dms[i])))));
}
k++;
}
return out_dims;
}
return Shape{};
}
void op::DynSlice::validate_and_infer_types() void op::DynSlice::validate_and_infer_types()
{ {
auto lower_bounds_et = get_input_element_type(1); auto lower_bounds_et = get_input_element_type(1);
...@@ -219,17 +84,24 @@ void op::DynSlice::validate_and_infer_types() ...@@ -219,17 +84,24 @@ void op::DynSlice::validate_and_infer_types()
set_input_is_relevant_to_shape(2); set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3); set_input_is_relevant_to_shape(3);
if (get_input_partial_shape(0).is_static()) auto lower_bounds = dynamic_pointer_cast<op::Constant>(get_argument(1));
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(2));
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(3));
if (lower_bounds && upper_bounds && strides)
{ {
auto shape = compute_output_shape(); set_output_type(0,
if (shape != Shape{}) get_input_element_type(0),
{ infer_slice_shape(this,
set_output_type(0, get_input_element_type(0), shape); get_input_partial_shape(0),
} lower_bounds->get_vector<int64_t>(),
else upper_bounds->get_vector<int64_t>(),
{ strides->get_vector<int64_t>(),
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank())); m_lower_bounds_mask,
} m_upper_bounds_mask,
m_new_axis,
m_shrink_axis,
m_ellipsis_mask));
} }
else else
{ {
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class for GRU cell node.
///
/// \note It follows notation and equations defined as in ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU
///
/// Note this class represents only single *cell* and not whole GRU *layer*.
///
class GRUCell : public util::FusedOp, public util::RNNCellBase
{
public:
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
///
GRUCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size);
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
///
GRUCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta,
float clip,
bool linear_before_reset);
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] B The bias tensor for input gate with shape:
/// [2 * gates_count * hidden_size].
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
///
GRUCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::shared_ptr<Node>& B,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
float clip = 0.f,
bool linear_before_reset = false);
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_linear_before_reset() const { return m_linear_before_reset; }
private:
/// brief Add and initialize bias input to all zeros.
void add_default_bias_input();
///
/// \brief The Activation function f.
///
util::ActivationFunction m_activation_f;
///
/// \brief The Activation function g.
///
util::ActivationFunction m_activation_g;
static constexpr std::size_t s_gates_count{3};
///
/// \brief Control whether or not apply the linear transformation.
///
/// \note The linear transformation may be applied when computing the output of hidden gate.
/// It's done before multiplying by the output of the reset gate.
///
bool m_linear_before_reset;
};
}
}
...@@ -24,11 +24,6 @@ ...@@ -24,11 +24,6 @@
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/lstm_cell.hpp" #include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -36,46 +31,6 @@ ...@@ -36,46 +31,6 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
// ------------- HELPER FUNCTIONS ---------------------------------------------
static shared_ptr<Node> add(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Add>(args.at(0), args.at(1))};
}
static shared_ptr<Node> sub(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Subtract>(args.at(0), args.at(1))};
}
static shared_ptr<Node> mul(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Multiply>(args.at(0), args.at(1))};
}
static shared_ptr<Node> clip(const shared_ptr<Node>& data, float threshold)
{
if (threshold == 0.f)
{
return data;
}
float min_val = -threshold;
float max_val = threshold;
size_t size = shape_size(data->get_shape());
const shared_ptr<Node> min_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, min_val));
const shared_ptr<Node> max_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, max_val));
return make_shared<op::Minimum>(max_val_node, make_shared<op::Maximum>(data, min_val_node));
}
// ------------- LSTM_CELL ----------------------------------------------------
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X, op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W, const shared_ptr<Node>& W,
const shared_ptr<Node>& R, const shared_ptr<Node>& R,
...@@ -166,18 +121,18 @@ void op::LSTMCell::pre_validate_and_infer_types() ...@@ -166,18 +121,18 @@ void op::LSTMCell::pre_validate_and_infer_types()
const Shape& ct_shape{ct_pshape.to_shape()}; const Shape& ct_shape{ct_pshape.to_shape()};
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
(w_shape == Shape{m_gates_count * get_hidden_size(), input_size}), (w_shape == Shape{s_gates_count * get_hidden_size(), input_size}),
"Input tensor W must have shape (", "Input tensor W must have shape (",
m_gates_count * get_hidden_size(), s_gates_count * get_hidden_size(),
", ", ", ",
input_size, input_size,
"). Actual shape is:", "). Actual shape is:",
w_shape, w_shape,
"."); ".");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
(r_shape == Shape{m_gates_count * get_hidden_size(), get_hidden_size()}), (r_shape == Shape{s_gates_count * get_hidden_size(), get_hidden_size()}),
"Input tensor R must have shape (", "Input tensor R must have shape (",
m_gates_count * get_hidden_size(), s_gates_count * get_hidden_size(),
", ", ", ",
get_hidden_size(), get_hidden_size(),
"). Actual shape is:", "). Actual shape is:",
...@@ -213,7 +168,7 @@ void op::LSTMCell::pre_validate_and_infer_types() ...@@ -213,7 +168,7 @@ void op::LSTMCell::pre_validate_and_infer_types()
const Shape& p_shape{p_pshape.to_shape()}; const Shape& p_shape{p_pshape.to_shape()};
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
(b_shape == Shape{2 * m_gates_count * get_hidden_size()}), (b_shape == Shape{2 * s_gates_count * get_hidden_size()}),
"Input tensor B must have shape (", "Input tensor B must have shape (",
8 * get_hidden_size(), 8 * get_hidden_size(),
"). Actual shape is:", "). Actual shape is:",
...@@ -221,9 +176,9 @@ void op::LSTMCell::pre_validate_and_infer_types() ...@@ -221,9 +176,9 @@ void op::LSTMCell::pre_validate_and_infer_types()
"."); ".");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
(p_shape == Shape{m_peepholes_count * get_hidden_size()}), (p_shape == Shape{s_peepholes_count * get_hidden_size()}),
"Input tensor P must have shape (", "Input tensor P must have shape (",
m_peepholes_count * get_hidden_size(), s_peepholes_count * get_hidden_size(),
"). Actual shape is:", "). Actual shape is:",
p_shape, p_shape,
"."); ".");
...@@ -295,7 +250,7 @@ NodeVector op::LSTMCell::decompose_op() const ...@@ -295,7 +250,7 @@ NodeVector op::LSTMCell::decompose_op() const
auto c_t = split_gates.at(3); auto c_t = split_gates.at(3);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
i_t = m_activation_f(clip(add(i_t, mul(p_i, C_t)), get_clip())); i_t = m_activation_f(clip(add(i_t, mul(p_i, C_t))));
if (m_input_forget) if (m_input_forget)
{ {
// Couple input with forget gate: 1 - i_t // Couple input with forget gate: 1 - i_t
...@@ -307,14 +262,14 @@ NodeVector op::LSTMCell::decompose_op() const ...@@ -307,14 +262,14 @@ NodeVector op::LSTMCell::decompose_op() const
else else
{ {
// f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) // f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
f_t = m_activation_f(clip(add(f_t, mul(p_f, C_t)), get_clip())); f_t = m_activation_f(clip(add(f_t, mul(p_f, C_t))));
} }
// ft (.) Ct-1 + it (.) ct // ft (.) Ct-1 + it (.) ct
auto C = add(mul(f_t, C_t), mul(i_t, m_activation_g(clip(c_t, get_clip())))); auto C = add(mul(f_t, C_t), mul(i_t, m_activation_g(clip(c_t))));
// f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) // f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
o_t = m_activation_f(clip(add(o_t, mul(p_o, C)), get_clip())); o_t = m_activation_f(clip(add(o_t, mul(p_o, C))));
// ot (.) h(Ct) // ot (.) h(Ct)
auto H = mul(o_t, m_activation_h(clip(C, get_clip()))); auto H = mul(o_t, m_activation_h(clip(C)));
return {H, C}; return {H, C};
} }
...@@ -332,15 +287,15 @@ NodeVector op::LSTMCell::get_peephole_weigths() const ...@@ -332,15 +287,15 @@ NodeVector op::LSTMCell::get_peephole_weigths() const
{ {
shared_ptr<Node> P; shared_ptr<Node> P;
P = get_argument(6); P = get_argument(6);
return builder::split(P, m_peepholes_count); return builder::split(P, s_peepholes_count);
} }
void op::LSTMCell::add_default_bias_input() void op::LSTMCell::add_default_bias_input()
{ {
shared_ptr<Node> B = shared_ptr<Node> B =
op::Constant::create(input(0).get_element_type(), op::Constant::create(input(0).get_element_type(),
Shape{2 * m_gates_count * get_hidden_size()}, Shape{2 * s_gates_count * get_hidden_size()},
vector<float>(2 * m_gates_count * get_hidden_size(), 0.f)); vector<float>(2 * s_gates_count * get_hidden_size(), 0.f));
set_argument(5, B->output(0)); set_argument(5, B->output(0));
} }
...@@ -348,8 +303,8 @@ void op::LSTMCell::add_default_peepholes_input() ...@@ -348,8 +303,8 @@ void op::LSTMCell::add_default_peepholes_input()
{ {
shared_ptr<Node> P = shared_ptr<Node> P =
op::Constant::create(input(0).get_element_type(), op::Constant::create(input(0).get_element_type(),
Shape{m_peepholes_count * get_hidden_size()}, Shape{s_peepholes_count * get_hidden_size()},
vector<float>(m_peepholes_count * get_hidden_size(), 0.f)); vector<float>(s_peepholes_count * get_hidden_size(), 0.f));
set_argument(6, P->output(0)); set_argument(6, P->output(0));
} }
......
...@@ -168,8 +168,8 @@ namespace ngraph ...@@ -168,8 +168,8 @@ namespace ngraph
/// ///
bool m_input_forget = false; bool m_input_forget = false;
static constexpr std::size_t m_gates_count{4}; static constexpr std::size_t s_gates_count{4};
static constexpr std::size_t m_peepholes_count{3}; static constexpr std::size_t s_peepholes_count{3};
}; };
} }
} }
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <functional>
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::RNNCell::RNNCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
size_t hidden_size)
: RNNCell(
X, W, R, H_t, hidden_size, vector<string>{"tanh"}, vector<float>{}, vector<float>{}, 0.f)
{
}
op::RNNCell::RNNCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
float clip)
: FusedOp("RNNCell", {X, W, R, H_t})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)}
{
add_default_bias_input();
constructor_validate_and_infer_types();
}
op::RNNCell::RNNCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
size_t hidden_size,
const shared_ptr<Node>& B,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
float clip)
: FusedOp("RNNCell", {X, W, R, H_t, B})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)}
{
constructor_validate_and_infer_types();
}
void op::RNNCell::pre_validate_and_infer_types()
{
const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2);
const auto& ht_pshape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
(x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() ||
ht_pshape.is_static()),
"RNNCell supports only static input tensors.");
const Shape& x_shape{x_pshape.to_shape()};
const size_t batch_size = x_shape.at(0);
const size_t input_size = x_shape.at(1);
const Shape& w_shape{w_pshape.to_shape()};
const Shape& r_shape{r_pshape.to_shape()};
const Shape& ht_shape{ht_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(w_shape == Shape{get_hidden_size(), input_size}),
"Input tensor W must have shape (",
get_hidden_size(),
", ",
input_size,
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(r_shape == Shape{get_hidden_size(), get_hidden_size()}),
"Input tensor R must have shape (",
get_hidden_size(),
", ",
get_hidden_size(),
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(ht_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor H_t must have shape (",
batch_size,
", ",
get_hidden_size(),
"). Actual shape is:",
w_shape,
".");
const auto& b_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(
this, b_pshape.is_static(), "RNNCell supports only static input tensors.");
const Shape& b_shape{b_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(b_shape == Shape{2 * get_hidden_size()}),
"Input tensor B must have shape (",
2 * get_hidden_size(),
"). Actual shape is:",
b_shape,
".");
}
NodeVector op::RNNCell::decompose_op() const
{
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the one used in ONNX documentation.
//
// ------ ACRONYMS ------
// i_t - input gate at current time step
// t - time step (t-1 means previous time step)
// X - The input data tensor. Shape: [batch_size, input_size].
// W - The weight tensor for input gate. Shape: [hidden_size, input_size].
// R - The recurrence weight tensor for input gate. Shape: [hidden_size, hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
// B - The bias tensor for the input gate. Shape: [2 * hidden_size] Concatenation of `[Wb, Rb]`.
// Wb - W bias vectors for input gate.
// Rb - R bias vectors for input gate.
// ------ VARIABLE NAMES ------
// Xt_W - Input sequence multiplied by weights tensor at current time step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
// ---- Equations ----
// f - is activation functions.
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// --------------------
std::shared_ptr<Node> X = get_argument(0);
std::shared_ptr<Node> W = get_argument(1);
std::shared_ptr<Node> R = get_argument(2);
std::shared_ptr<Node> H_t = get_argument(3);
std::shared_ptr<Node> bias = get_bias();
// Xt*(W^T)
auto Xt_W = std::make_shared<op::Dot>(X, builder::transpose(W));
// Ht-1*(R^T)
auto Ht_R = std::make_shared<op::Dot>(H_t, builder::transpose(R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
auto i_t = add(Xt_W, add(Ht_R, bias));
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
i_t = m_activation_f(clip(i_t));
return {i_t};
}
shared_ptr<Node> op::RNNCell::get_bias() const
{
shared_ptr<Node> bias;
// Split B onto Wb an Rb and add them.
NodeVector b_W_R = builder::split(get_argument(4), 2);
bias = b_W_R.at(0) + b_W_R.at(1);
return bias;
}
void op::RNNCell::add_default_bias_input()
{
shared_ptr<Node> B =
op::Constant::create(input(0).get_element_type(),
Shape{2 * s_gates_count * get_hidden_size()},
vector<float>(2 * s_gates_count * get_hidden_size(), 0.f));
set_argument(4, B->output(0));
}
shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
if (new_args.size() == 4)
{
return make_shared<RNNCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
get_hidden_size(),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_clip());
}
else if (new_args.size() == 5)
{
return make_shared<RNNCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
get_hidden_size(),
new_args.at(4),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_clip());
}
else
{
throw ngraph_error("Incorrect number of new arguments");
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class for RNN cell node.
///
/// \note It follows notation and equations defined as in ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN
///
/// Note this class represents only single *cell* and not whole RNN *layer*.
///
class RNNCell : public util::FusedOp, public util::RNNCellBase
{
public:
///
/// \brief Constructs RNNCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
///
RNNCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size);
///
/// \brief Constructs RNNCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
///
RNNCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta,
float clip);
///
/// \brief Constructs RNNCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] B The bias tensor for input gate with shape: [2*hidden_size].
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
///
RNNCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::shared_ptr<Node>& B,
const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
float clip = 0.f);
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
std::shared_ptr<Node> get_bias() const;
/// brief Add and initialize bias input to all zeros.
void add_default_bias_input();
///
/// \brief The Activation function f.
///
util::ActivationFunction m_activation_f;
static constexpr std::size_t s_gates_count{1};
};
}
}
...@@ -28,12 +28,14 @@ NGRAPH_OP(Gemm, ngraph::op) ...@@ -28,12 +28,14 @@ NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GRN, ngraph::op) NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op) NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op) NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op) NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LeakyRelu, ngraph::op) NGRAPH_OP(LeakyRelu, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op) NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(MVN, ngraph::op) NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Normalize, ngraph::op) NGRAPH_OP(Normalize, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op) NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op) NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(ShuffleChannels, ngraph::op) NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op) NGRAPH_OP(SpaceToDepth, ngraph::op)
......
...@@ -84,6 +84,7 @@ NGRAPH_OP(Divide, ngraph::op) ...@@ -84,6 +84,7 @@ NGRAPH_OP(Divide, ngraph::op)
NGRAPH_OP(Dot, ngraph::op) NGRAPH_OP(Dot, ngraph::op)
NGRAPH_OP(DynBroadcast, ngraph::op) NGRAPH_OP(DynBroadcast, ngraph::op)
NGRAPH_OP(DynPad, ngraph::op) NGRAPH_OP(DynPad, ngraph::op)
NGRAPH_OP(DynReplaceSlice, ngraph::op)
NGRAPH_OP(DynReshape, ngraph::op) NGRAPH_OP(DynReshape, ngraph::op)
NGRAPH_OP(DynSlice, ngraph::op) NGRAPH_OP(DynSlice, ngraph::op)
NGRAPH_OP(EmbeddingLookup, ngraph::op) NGRAPH_OP(EmbeddingLookup, ngraph::op)
......
...@@ -84,7 +84,7 @@ void op::Pad::validate_and_infer_types() ...@@ -84,7 +84,7 @@ void op::Pad::validate_and_infer_types()
if (arg_shape[i].is_static()) if (arg_shape[i].is_static())
{ {
ptrdiff_t result_dim = ptrdiff_t result_dim =
m_padding_below[i] + static_cast<ptrdiff_t>(arg_shape[i]) + m_padding_above[i]; m_padding_below[i] + static_cast<int64_t>(arg_shape[i]) + m_padding_above[i];
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
result_dim >= 0, result_dim >= 0,
"Inferred result dimension at axis ", "Inferred result dimension at axis ",
......
...@@ -17,6 +17,14 @@ ...@@ -17,6 +17,14 @@
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp" #include "ngraph/op/util/rnn_cell_base.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -60,3 +68,34 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size ...@@ -60,3 +68,34 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size
return afunc; return afunc;
} }
shared_ptr<Node> op::util::RNNCellBase::add(const shared_ptr<Node>& lhs,
const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Add>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::util::RNNCellBase::sub(const shared_ptr<Node>& lhs,
const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Subtract>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::util::RNNCellBase::mul(const shared_ptr<Node>& lhs,
const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Multiply>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::util::RNNCellBase::clip(const shared_ptr<Node>& data) const
{
if (m_clip == 0.f)
{
return data;
}
return make_shared<op::Clamp>(data, -m_clip, m_clip);
}
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp" #include "ngraph/op/util/activation_functions.hpp"
namespace ngraph namespace ngraph
...@@ -71,10 +73,48 @@ namespace ngraph ...@@ -71,10 +73,48 @@ namespace ngraph
/// \return The object representing activation function. /// \return The object representing activation function.
/// ///
ActivationFunction get_activation_function(std::size_t idx) const; ActivationFunction get_activation_function(std::size_t idx) const;
///
/// \brief Creates node with element-wise add operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise add operation.
///
static std::shared_ptr<Node> add(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
///
/// \brief Creates node with element-wise subtract operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise subtract operation.
///
static std::shared_ptr<Node> sub(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
///
/// \brief Creates node with element-wise multiply operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise multiply operation.
///
static std::shared_ptr<Node> mul(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
///
/// \brief Creates node with element-wise clip operation with numpy broadcasting.
///
/// \param[in] data The input tensor for clipping.
///
/// \return Node with element-wise clip operation.
///
std::shared_ptr<Node> clip(const std::shared_ptr<Node>& data) const;
private: private:
std::size_t m_hidden_size = 0.f; const std::size_t m_hidden_size;
float m_clip = 0.f; const float m_clip;
const std::vector<std::string> m_activations; const std::vector<std::string> m_activations;
const std::vector<float> m_activation_alpha; const std::vector<float> m_activation_alpha;
const std::vector<float> m_activation_beta; const std::vector<float> m_activation_beta;
......
...@@ -275,3 +275,16 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst, ...@@ -275,3 +275,16 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
return success; return success;
} }
} }
bool PartialShape::all_non_negative() const
{
for (auto& d : m_dimensions)
{
if (d.is_static() && int64_t(d) < 0)
{
return false;
}
}
return true;
}
...@@ -164,6 +164,10 @@ namespace ngraph ...@@ -164,6 +164,10 @@ namespace ngraph
/// \throws std::invalid_argument If this PartialShape is dynamic. /// \throws std::invalid_argument If this PartialShape is dynamic.
Shape to_shape() const; Shape to_shape() const;
/// \brief Returns `true` if all static dimensions of the tensor are non-negative, else
/// `false`.
bool all_non_negative() const;
/// \brief Index operator for PartialShape. /// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected. /// \param i The index of the dimension being selected.
/// \return A reference to the `i`th Dimension of this shape. /// \return A reference to the `i`th Dimension of this shape.
......
...@@ -14,12 +14,17 @@ ...@@ -14,12 +14,17 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <numeric>
#include "dyn_elimination.hpp" #include "dyn_elimination.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp" #include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp" #include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp" #include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/transpose.hpp" #include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
...@@ -33,7 +38,9 @@ pass::DynElimination::DynElimination() ...@@ -33,7 +38,9 @@ pass::DynElimination::DynElimination()
: GraphRewrite() : GraphRewrite()
{ {
construct_transpose(); construct_transpose();
construct_broadcast(); construct_dyn_broadcast();
construct_dyn_replace_slice();
construct_dyn_slice();
construct_dyn_reshape(); construct_dyn_reshape();
construct_range(); construct_range();
} }
...@@ -85,7 +92,7 @@ void pass::DynElimination::construct_transpose() ...@@ -85,7 +92,7 @@ void pass::DynElimination::construct_transpose()
add_matcher(transpose_matcher, transpose_callback, all_pass_property_off); add_matcher(transpose_matcher, transpose_callback, all_pass_property_off);
} }
void pass::DynElimination::construct_broadcast() void pass::DynElimination::construct_dyn_broadcast()
{ {
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3}); auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto shape_arg_label = auto shape_arg_label =
...@@ -367,7 +374,7 @@ static SlicePlan make_plan(const Shape& input_shape, ...@@ -367,7 +374,7 @@ static SlicePlan make_plan(const Shape& input_shape,
return p; return p;
} }
void pass::DynElimination::construct_dyn_reshape() void pass::DynElimination::construct_dyn_slice()
{ {
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3}); auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto begins_arg_label = auto begins_arg_label =
...@@ -436,10 +443,139 @@ void pass::DynElimination::construct_dyn_reshape() ...@@ -436,10 +443,139 @@ void pass::DynElimination::construct_dyn_reshape()
}; };
auto dyn_slice_matcher = auto dyn_slice_matcher =
make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynShape"); make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynSlice");
add_matcher(dyn_slice_matcher, dyn_slice_callback, all_pass_property_off); add_matcher(dyn_slice_matcher, dyn_slice_callback, all_pass_property_off);
} }
void pass::DynElimination::construct_dyn_replace_slice()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto replacement_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto begins_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto ends_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto strides_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto dyn_replace_slice_pat = make_shared<op::DynReplaceSlice>(data_arg_label,
replacement_arg_label,
begins_arg_label,
ends_arg_label,
strides_arg_label,
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{});
auto dyn_replace_slice_callback = [data_arg_label,
replacement_arg_label,
begins_arg_label,
ends_arg_label,
strides_arg_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto data_arg = pattern_map[data_arg_label];
auto replacement_arg = pattern_map[replacement_arg_label];
auto begins_arg = static_pointer_cast<op::Constant>(pattern_map[begins_arg_label]);
auto ends_arg = static_pointer_cast<op::Constant>(pattern_map[ends_arg_label]);
auto strides_arg = static_pointer_cast<op::Constant>(pattern_map[strides_arg_label]);
auto dyn_replace_slice = static_pointer_cast<op::DynReplaceSlice>(m.get_match_root());
if (data_arg->get_output_partial_shape(0).is_dynamic() ||
replacement_arg->get_output_partial_shape(0).is_dynamic() ||
begins_arg->get_element_type() != element::i64 ||
ends_arg->get_element_type() != element::i64 ||
strides_arg->get_element_type() != element::i64)
{
return false;
}
SlicePlan p = make_plan(data_arg->get_output_shape(0),
begins_arg->get_vector<int64_t>(),
ends_arg->get_vector<int64_t>(),
strides_arg->get_vector<int64_t>(),
dyn_replace_slice->get_lower_bounds_mask(),
dyn_replace_slice->get_upper_bounds_mask(),
dyn_replace_slice->get_new_axis(),
dyn_replace_slice->get_shrink_axis(),
dyn_replace_slice->get_ellipsis_mask());
shared_ptr<Node> substitute_replacement_arg = replacement_arg;
if (!p.reverse_axes.empty())
{
substitute_replacement_arg =
make_shared<op::Reverse>(substitute_replacement_arg, p.reverse_axes);
}
if (p.reshape_in_shape != p.reshape_out_shape)
{
substitute_replacement_arg =
make_shared<op::Reshape>(substitute_replacement_arg,
ngraph::get_default_order(p.reshape_out_shape),
p.reshape_in_shape);
}
auto substitute_rsl =
make_shared<op::ReplaceSlice>(data_arg,
substitute_replacement_arg,
Coordinate(p.begins.begin(), p.begins.end()),
Coordinate(p.ends.begin(), p.ends.end()),
Strides(p.strides.begin(), p.strides.end()));
replace_node(m.get_match_root(), substitute_rsl);
return true;
};
auto dyn_replace_slice_matcher =
make_shared<pattern::Matcher>(dyn_replace_slice_pat, "DynElimination.DynReplaceShape");
add_matcher(dyn_replace_slice_matcher, dyn_replace_slice_callback, all_pass_property_off);
}
void pass::DynElimination::construct_dyn_reshape()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto shape_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto dyn_reshape = make_shared<op::DynReshape>(data_arg_label, shape_arg_label);
auto dyn_reshape_callback = [data_arg_label, shape_arg_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto data_arg = pattern_map[data_arg_label];
auto shape_arg = static_pointer_cast<op::Constant>(pattern_map[shape_arg_label]);
auto dyn_reshape_node = static_pointer_cast<op::DynReshape>(m.get_match_root());
// TODO(amprocte): Can't handle the case where data rank is dynamic even if we know the
// output shape, because static Reshape requries an axis permutation (here an identity) to
// be given. See if we can come up with a workaround.
if (data_arg->get_output_partial_shape(0).rank().is_dynamic())
{
return false;
}
if (dyn_reshape_node->get_output_partial_shape(0).is_dynamic())
{
return false;
}
auto& result_shape = dyn_reshape_node->get_output_shape(0);
AxisVector perm(size_t(data_arg->get_output_partial_shape(0).rank()));
std::iota(perm.begin(), perm.end(), 0);
auto replacement = std::make_shared<op::Reshape>(data_arg, perm, result_shape);
replace_node(dyn_reshape_node, replacement);
return true;
};
auto dyn_reshape_matcher =
make_shared<pattern::Matcher>(dyn_reshape, "DynElimination.DynReshape");
add_matcher(dyn_reshape_matcher, dyn_reshape_callback, all_pass_property_off);
}
template <typename T> template <typename T>
std::shared_ptr<op::Constant> std::shared_ptr<op::Constant>
make_range_replacement_integral(const element::Type& et, make_range_replacement_integral(const element::Type& et,
......
...@@ -30,7 +30,9 @@ namespace ngraph ...@@ -30,7 +30,9 @@ namespace ngraph
private: private:
void construct_transpose(); void construct_transpose();
void construct_broadcast(); void construct_dyn_broadcast();
void construct_dyn_replace_slice();
void construct_dyn_slice();
void construct_dyn_reshape(); void construct_dyn_reshape();
void construct_range(); void construct_range();
}; };
......
...@@ -103,17 +103,23 @@ shared_ptr<runtime::Executable> ...@@ -103,17 +103,23 @@ shared_ptr<runtime::Executable>
#endif #endif
shared_ptr<runtime::Executable> rc; shared_ptr<runtime::Executable> rc;
auto it = m_exec_map.find(func); // we will protect the access to map (m_exec_map) across multiple threads by creating a lock_gaurd
if (it != m_exec_map.end()) // m_exec_map_mutex will be released once the object `guard` goes out of scope
{ {
rc = it->second; std::lock_guard<std::mutex> guard(m_exec_map_mutex);
auto it = m_exec_map.find(func);
if (it != m_exec_map.end())
{
rc = it->second;
return rc;
}
} }
else rc = make_shared<CPU_Executable>(func, pass_config, performance_counters_enabled);
{ {
rc = make_shared<CPU_Executable>(func, pass_config, performance_counters_enabled); std::lock_guard<std::mutex> guard(m_exec_map_mutex);
m_exec_map.insert({func, rc}); m_exec_map.insert({func, rc});
return rc;
} }
return rc;
} }
runtime::cpu::CPU_Executable::CPU_Executable(shared_ptr<Function> func, runtime::cpu::CPU_Executable::CPU_Executable(shared_ptr<Function> func,
...@@ -156,6 +162,7 @@ bool runtime::cpu::CPU_Executable::call(const vector<shared_ptr<runtime::Tensor> ...@@ -156,6 +162,7 @@ bool runtime::cpu::CPU_Executable::call(const vector<shared_ptr<runtime::Tensor>
void runtime::cpu::CPU_Backend::remove_compiled_function(shared_ptr<Executable> exec) void runtime::cpu::CPU_Backend::remove_compiled_function(shared_ptr<Executable> exec)
{ {
std::lock_guard<std::mutex> guard(m_exec_map_mutex);
for (auto it = m_exec_map.begin(); it != m_exec_map.end(); ++it) for (auto it = m_exec_map.begin(); it != m_exec_map.end(); ++it)
{ {
if (it->second == exec) if (it->second == exec)
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex>
#include "cpu_backend_visibility.h" #include "cpu_backend_visibility.h"
#include "ngraph/pass/pass_config.hpp" #include "ngraph/pass/pass_config.hpp"
...@@ -63,6 +64,9 @@ namespace ngraph ...@@ -63,6 +64,9 @@ namespace ngraph
bool is_supported_property(const Property prop) const override; bool is_supported_property(const Property prop) const override;
private: private:
// this mutex will be used to protect the addition and deletion
// of function to m_exec_map across multiple threads
std::mutex m_exec_map_mutex;
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<Executable>> std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<Executable>>
m_exec_map; m_exec_map;
}; };
......
...@@ -218,6 +218,7 @@ bool runtime::gpu::GPU_Backend::is_supported(const Node& op) const ...@@ -218,6 +218,7 @@ bool runtime::gpu::GPU_Backend::is_supported(const Node& op) const
{ {
set<string> unsupported_ops = {"Quantize", set<string> unsupported_ops = {"Quantize",
"Dequantize", "Dequantize",
"DynReplaceSlice",
"DynReshape", "DynReshape",
"DynSlice", "DynSlice",
"ShapeOf", "ShapeOf",
......
...@@ -62,6 +62,7 @@ ...@@ -62,6 +62,7 @@
#include "ngraph/op/experimental/batch_mat_mul.hpp" #include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp" #include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp" #include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp" #include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp" #include "ngraph/op/experimental/generate_mask.hpp"
...@@ -612,6 +613,11 @@ std::string runtime::gpu::GPU_Emitter::emit_Dot(EMIT_ARGS) ...@@ -612,6 +613,11 @@ std::string runtime::gpu::GPU_Emitter::emit_Dot(EMIT_ARGS)
return compiled_function->add_to_runtime(index, function_name, args, out); return compiled_function->add_to_runtime(index, function_name, args, out);
} }
std::string runtime::gpu::GPU_Emitter::emit_DynReplaceSlice(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_DynReshape(EMIT_ARGS) std::string runtime::gpu::GPU_Emitter::emit_DynReshape(EMIT_ARGS)
{ {
throw unsupported_op("Unsupported op '" + node->description() + "'"); throw unsupported_op("Unsupported op '" + node->description() + "'");
......
...@@ -87,11 +87,13 @@ ...@@ -87,11 +87,13 @@
#include "ngraph/op/fused/grn.hpp" #include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp" #include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp" #include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp" #include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp" #include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp" #include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp" #include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp" #include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp" #include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp" #include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/space_to_depth.hpp" #include "ngraph/op/fused/space_to_depth.hpp"
...@@ -2059,6 +2061,7 @@ shared_ptr<runtime::Executable> ...@@ -2059,6 +2061,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::DepthToSpace: case OP_TYPEID::DepthToSpace:
case OP_TYPEID::DynBroadcast: case OP_TYPEID::DynBroadcast:
case OP_TYPEID::DynPad: case OP_TYPEID::DynPad:
case OP_TYPEID::DynReplaceSlice:
case OP_TYPEID::DynReshape: case OP_TYPEID::DynReshape:
case OP_TYPEID::DynSlice: case OP_TYPEID::DynSlice:
case OP_TYPEID::Elu: case OP_TYPEID::Elu:
...@@ -2070,6 +2073,7 @@ shared_ptr<runtime::Executable> ...@@ -2070,6 +2073,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
case OP_TYPEID::GRN: case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolutionTranspose: case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::GRUCell:
case OP_TYPEID::HardSigmoid: case OP_TYPEID::HardSigmoid:
case OP_TYPEID::LeakyRelu: case OP_TYPEID::LeakyRelu:
case OP_TYPEID::LSTMCell: case OP_TYPEID::LSTMCell:
...@@ -2077,6 +2081,7 @@ shared_ptr<runtime::Executable> ...@@ -2077,6 +2081,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::Normalize: case OP_TYPEID::Normalize:
case OP_TYPEID::PRelu: case OP_TYPEID::PRelu:
case OP_TYPEID::Passthrough: case OP_TYPEID::Passthrough:
case OP_TYPEID::RNNCell:
case OP_TYPEID::QuantizedAvgPool: case OP_TYPEID::QuantizedAvgPool:
case OP_TYPEID::QuantizedConvolution: case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedConvolutionBias: case OP_TYPEID::QuantizedConvolutionBias:
...@@ -2195,11 +2200,13 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node) ...@@ -2195,11 +2200,13 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::Gemm: case OP_TYPEID::Gemm:
case OP_TYPEID::GRN: case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolutionTranspose: case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::GRUCell:
case OP_TYPEID::LeakyRelu: case OP_TYPEID::LeakyRelu:
case OP_TYPEID::LSTMCell: case OP_TYPEID::LSTMCell:
case OP_TYPEID::MVN: case OP_TYPEID::MVN:
case OP_TYPEID::Normalize: case OP_TYPEID::Normalize:
case OP_TYPEID::PRelu: case OP_TYPEID::PRelu:
case OP_TYPEID::RNNCell:
case OP_TYPEID::ScaleShift: case OP_TYPEID::ScaleShift:
case OP_TYPEID::ShuffleChannels: case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::SpaceToDepth: case OP_TYPEID::SpaceToDepth:
......
...@@ -18,6 +18,7 @@ replace_slice_matrix ...@@ -18,6 +18,7 @@ replace_slice_matrix
replace_slice_matrix_inplace replace_slice_matrix_inplace
replace_slice_scalar replace_slice_scalar
replace_slice_vector replace_slice_vector
dyn_replace_slice
shape_of_5d shape_of_5d
shape_of_matrix shape_of_matrix
shape_of_scalar shape_of_scalar
......
...@@ -1503,7 +1503,8 @@ private: ...@@ -1503,7 +1503,8 @@ private:
case OP_TYPEID::Transpose: case OP_TYPEID::Transpose:
case OP_TYPEID::DynPad: case OP_TYPEID::DynPad:
case OP_TYPEID::Tile: case OP_TYPEID::Tile:
default: throw unsupported_op("Unsupported op '" + node.description() + "'"); case OP_TYPEID::DynReplaceSlice:
throw unsupported_op("Unsupported op '" + node.description() + "'");
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8)) #if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#endif #endif
......
...@@ -259,6 +259,12 @@ backwards_softmax_underflow ...@@ -259,6 +259,12 @@ backwards_softmax_underflow
backwards_softmax_3d backwards_softmax_3d
batch_mat_mul_forward batch_mat_mul_forward
dot_matrix_2x0_0x2 dot_matrix_2x0_0x2
rnn_cell_no_bias
rnn_cell_bias_clip
rnn_cell_activation_function
gru_cell_bias_clip
gru_cell_linear_before_reset
gru_cell_activation_function
# dgkutnic ww24.5: these tests are to be triaged by the PlaidML team # dgkutnic ww24.5: these tests are to be triaged by the PlaidML team
# ww25.2: re-scrubbed this list of tests after fixing check_inputs # ww25.2: re-scrubbed this list of tests after fixing check_inputs
......
This diff is collapsed.
...@@ -142,8 +142,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -142,8 +142,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
ptrdiff_t data_padded_dilated_dim = -1; ptrdiff_t data_padded_dilated_dim = -1;
if (data_dim_static) if (data_dim_static)
{ {
data_padded_dilated_dim = (static_cast<ptrdiff_t>(data_dilation[i]) * data_padded_dilated_dim = (static_cast<int64_t>(data_dilation[i]) *
(static_cast<ptrdiff_t>(data_shape[i]) - 1)) + (static_cast<int64_t>(data_shape[i]) - 1)) +
1 + data_padding_below[i] + data_padding_above[i]; 1 + data_padding_below[i] + data_padding_above[i];
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
node, node,
...@@ -158,8 +158,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -158,8 +158,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
ptrdiff_t window_dilated_dim = -1; ptrdiff_t window_dilated_dim = -1;
if (window_dim_static) if (window_dim_static)
{ {
window_dilated_dim = static_cast<ptrdiff_t>(window_dilation[i]) * window_dilated_dim = static_cast<int64_t>(window_dilation[i]) *
(static_cast<ptrdiff_t>(window_shape[i]) - 1) + (static_cast<int64_t>(window_shape[i]) - 1) +
1; 1;
NODE_VALIDATION_CHECK(node, NODE_VALIDATION_CHECK(node,
...@@ -628,3 +628,257 @@ void ngraph::infer_auto_padding(const Shape& image_shape, ...@@ -628,3 +628,257 @@ void ngraph::infer_auto_padding(const Shape& image_shape,
padding_above.push_back(pad_type == op::PadType::SAME_UPPER ? padding_rhs : padding_lhs); padding_above.push_back(pad_type == op::PadType::SAME_UPPER ? padding_rhs : padding_lhs);
} }
} }
PartialShape ngraph::infer_slice_shape(const Node* node,
const PartialShape& input_shape,
const std::vector<int64_t>& lb,
const std::vector<int64_t>& ub,
const std::vector<int64_t>& str,
const AxisSet& lb_mask,
const AxisSet& ub_mask,
const AxisSet& new_axis,
const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask)
{
if (lb.size() && ub.size())
{
NODE_VALIDATION_CHECK(node,
lb.size() == ub.size(),
"Lower bounds and Upper bounds needs to have same number of values");
}
if (lb.size() && str.size())
{
NODE_VALIDATION_CHECK(node,
lb.size() == str.size(),
"Lower bounds and strides needs to have same number of values");
}
if (ub.size() && str.size())
{
NODE_VALIDATION_CHECK(node,
ub.size() == str.size(),
"Upper bounds and strides needs to have same number of values");
}
if (input_shape.rank().is_dynamic())
{
return PartialShape::dynamic();
}
int max_dims = size_t(input_shape.rank()) + new_axis.size();
int bounds_size =
lb.size() ? lb.size() : (ub.size() ? ub.size() : (str.size() ? str.size() : 0));
int ellipsis_pos1 = ellipsis_mask.size() ? *ellipsis_mask.begin() : max_dims;
int ellipsis_pos2 = max_dims;
bounds_size -= ellipsis_pos1;
if (bounds_size > 0 && (max_dims - bounds_size) > ellipsis_pos1)
{
ellipsis_pos2 = max_dims - bounds_size;
}
std::vector<Dimension> begin_dms(max_dims, 0);
std::vector<Dimension> end_dms(max_dims, -1);
std::vector<Dimension> stride_dms(max_dims, 1);
std::vector<Dimension> out_dims;
int j = 0;
int k = 0;
int bj = 0;
int ej = 0;
int sj = 0;
for (int i = 0; i < max_dims; i++)
{
if (i >= ellipsis_pos1 && i < ellipsis_pos2)
{
if (new_axis.find(i) == new_axis.end())
{
if (end_dms[i].is_static() && int64_t(end_dms[i]) < 0)
{
end_dms[i] = input_shape[j++] + end_dms[i];
}
}
else
{
end_dms[i] = begin_dms[i];
}
if (end_dms[i].is_dynamic() || begin_dms[i].is_dynamic() || stride_dms[i].is_dynamic())
{
out_dims.push_back(Dimension::dynamic());
}
else
{
out_dims.push_back(static_cast<int64_t>(
ceil(static_cast<float>(abs(int64_t(end_dms[i]) - int64_t(begin_dms[i])) + 1) /
static_cast<float>(abs(int64_t(stride_dms[i]))))));
}
k = ellipsis_pos1;
continue;
}
stride_dms[i] = (str.size() > sj && str[sj] != 0) ? str[sj++] : 1;
// Use lower_bounds if mask is not set
if (lb_mask.find(j) == lb_mask.end())
{
if (lb.size() > bj)
{
begin_dms[i] = lb[bj];
}
else if (stride_dms[i].is_dynamic())
{
begin_dms[i] = Dimension::dynamic();
}
else if (int64_t(stride_dms[i]) > 0)
{
begin_dms[i] = 0;
}
else
{
begin_dms[i] = -1;
}
}
else if (stride_dms[i].is_dynamic())
{
begin_dms[i] = Dimension::dynamic();
}
else if (int64_t(stride_dms[i]) > 0)
{
begin_dms[i] = 0;
}
else
{
begin_dms[i] = -1;
}
bj++;
if (begin_dms[i].is_static() && int64_t(begin_dms[i]) < 0)
{
begin_dms[i] = input_shape[j] + begin_dms[i];
}
// Clipping 'begin'
if (begin_dms[i].is_static())
{
if (int64_t(begin_dms[i]) < 0)
{
begin_dms[i] = 0;
}
else if (input_shape[j].is_dynamic())
{
begin_dms[i] = Dimension::dynamic();
}
else if (int64_t(begin_dms[i]) >= int64_t(input_shape[j]))
{
begin_dms[i] = input_shape[j] - 1;
}
}
// Use upper_bounds if mask is not set
if (ub_mask.find(j) == ub_mask.end())
{
Dimension end_dms_tmp;
if (ub.size() <= ej)
{
end_dms_tmp = end_dms[i];
}
else if (stride_dms[i].is_dynamic())
{
end_dms_tmp = Dimension::dynamic();
}
else if (int64_t(stride_dms[i]) > 0)
{
end_dms_tmp = ub[ej] - 1;
}
else
{
end_dms_tmp = ub[ej] + 1;
}
if (ub.size() > ej)
{
end_dms[i] = end_dms_tmp;
}
else if (stride_dms[i].is_dynamic())
{
end_dms[i] = Dimension::dynamic();
}
else if (int64_t(stride_dms[i]) > 0)
{
end_dms[i] = -1;
}
else
{
end_dms[i] = 0;
}
}
else
{
if (stride_dms[i].is_dynamic())
{
end_dms[i] = Dimension::dynamic();
}
else if (int64_t(stride_dms[i]) > 0)
{
end_dms[i] = -1;
}
else
{
end_dms[i] = 0;
}
}
ej++;
if (end_dms[i].is_static() && int64_t(end_dms[i]) < 0)
{
end_dms[i] = input_shape[j] + end_dms[i];
}
// Clipping 'end'
if (end_dms[i].is_static())
{
if (int64_t(end_dms[i]) < 0)
{
end_dms[i] = 0;
}
else if (input_shape[j].is_dynamic())
{
end_dms[i] = Dimension::dynamic();
}
else if (int64_t(end_dms[i]) >= int64_t(input_shape[j]))
{
end_dms[i] = input_shape[j] - 1;
}
}
if (new_axis.find(i) == new_axis.end())
{
j++;
}
else
{
end_dms[i] = 0;
}
if (shrink_axis.find(k) != shrink_axis.end())
{
end_dms[i] = begin_dms[i];
}
else if (end_dms[i].is_dynamic() || begin_dms[i].is_dynamic() || stride_dms[i].is_dynamic())
{
out_dims.push_back(Dimension::dynamic());
}
else
{
out_dims.push_back(static_cast<int64_t>(
ceil(static_cast<float>(abs(int64_t(end_dms[i]) - int64_t(begin_dms[i])) + 1) /
static_cast<float>(abs(int64_t(stride_dms[i]))))));
}
k++;
}
return out_dims;
}
...@@ -94,4 +94,15 @@ namespace ngraph ...@@ -94,4 +94,15 @@ namespace ngraph
const op::PadType pad_type, const op::PadType pad_type,
CoordinateDiff& padding_above, CoordinateDiff& padding_above,
CoordinateDiff& padding_below); CoordinateDiff& padding_below);
PartialShape infer_slice_shape(const Node* node,
const PartialShape& input_shape,
const std::vector<int64_t>& lb,
const std::vector<int64_t>& ub,
const std::vector<int64_t>& str,
const AxisSet& lb_mask,
const AxisSet& ub_mask,
const AxisSet& new_axis,
const AxisSet& shrink_mask,
const AxisSet& ellipsis_mask);
} }
...@@ -167,6 +167,7 @@ set(MULTI_TEST_SRC ...@@ -167,6 +167,7 @@ set(MULTI_TEST_SRC
backend_test.in.cpp backend_test.in.cpp
backend_unary_elementwise.in.cpp backend_unary_elementwise.in.cpp
convolution_test.in.cpp convolution_test.in.cpp
dyn_replace_slice_test.in.cpp
dyn_slice_test.in.cpp dyn_slice_test.in.cpp
dynamic.in.cpp dynamic.in.cpp
) )
......
This diff is collapsed.
...@@ -132,6 +132,80 @@ TEST(dyn_elimination, slice) ...@@ -132,6 +132,80 @@ TEST(dyn_elimination, slice)
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 4, 2, 2, 1, 2, 2})); ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 4, 2, 2, 1, 2, 2}));
} }
TEST(dyn_elimination, replace_slice)
{
// input has shape [2,4,6,8,2,2,2]
// slice in numpy syntax is [0:,:4,2:6:2,7:3:-2,np.newaxis,...,1]
// slice shape should be [2,4,2,2,1,2,2] (so sayeth numpy!)
Shape shape_in{2, 4, 6, 8, 2, 2, 2};
Shape shape_slice{2, 4, 2, 2, 1, 2, 2};
auto input = make_shared<op::Parameter>(element::f32, shape_in);
auto replacement = make_shared<op::Parameter>(element::f32, shape_slice);
auto constant_lb =
make_shared<op::Constant>(element::i64, Shape{7}, vector<int64_t>{0, 3, 2, 7, 0, 0, 1});
auto constant_ub =
make_shared<op::Constant>(element::i64, Shape{7}, vector<int64_t>{0, 4, 6, 3, 0, 0, 0});
auto constant_strides =
make_shared<op::Constant>(element::i64, Shape{7}, vector<int64_t>{1, 1, 2, -2, 0, 0, 0});
AxisSet lower_bounds_mask{1};
AxisSet upper_bounds_mask{0};
AxisSet new_axis_mask{4};
AxisSet shrink_mask{6};
AxisSet ellipsis_mask{5};
auto rsl = make_shared<op::DynReplaceSlice>(input,
replacement,
constant_lb,
constant_ub,
constant_strides,
lower_bounds_mask,
upper_bounds_mask,
new_axis_mask,
shrink_mask,
ellipsis_mask);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_EQ(rsl->get_shape(), (Shape{2, 4, 6, 8, 2, 2, 2}));
auto f = make_shared<Function>(rsl, ParameterVector{input, replacement});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReplaceSlice>(f), 0);
ASSERT_EQ(count_ops_of_type<op::ReplaceSlice>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Reverse>(f), 1);
ASSERT_EQ(f->get_results().at(0)->get_element_type(), element::f32);
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 4, 6, 8, 2, 2, 2}));
}
TEST(dyn_elimination, reshape)
{
auto input_arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto shape_arg = make_shared<op::Constant>(element::i64, Shape{3}, vector<int64_t>{0, 6, -1});
auto r = make_shared<op::DynReshape>(input_arg, shape_arg, true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{2, 6, 32}));
auto f = make_shared<Function>(r, ParameterVector{input_arg});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 1);
ASSERT_EQ(f->get_results().at(0)->get_element_type(), element::f32);
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 6, 32}));
}
TEST(dyn_elimination, range) TEST(dyn_elimination, range)
{ {
auto constant_start = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{0}); auto constant_start = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{0});
......
This diff is collapsed.
...@@ -365,3 +365,94 @@ NGRAPH_TEST(dynamic_${BACKEND_NAME}, range) ...@@ -365,3 +365,94 @@ NGRAPH_TEST(dynamic_${BACKEND_NAME}, range)
ASSERT_EQ(results, test.expected_result); ASSERT_EQ(results, test.expected_result);
} }
} }
NGRAPH_TEST(dynamic_${BACKEND_NAME}, reshape)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto build_graph = [&backend](bool zero_flag) {
// Create a graph for f(x,shape) = DynReshape(x,shape,zero_flag=zero_flag).
auto x = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto dyn_reshape = make_shared<op::DynReshape>(x, shape, zero_flag);
EXPECT_TRUE(dyn_reshape->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
auto f = make_shared<Function>(NodeVector{dyn_reshape}, ParameterVector{x, shape});
auto ex = backend->compile(f);
return ex;
};
auto t_r = backend->create_dynamic_tensor(element::i32, PartialShape::dynamic());
auto ex_flag_off = build_graph(false);
auto ex_flag_on = build_graph(true);
std::vector<std::tuple<bool, Shape, std::vector<int32_t>, std::vector<int64_t>, Shape>> tests;
tests.emplace_back(make_tuple(
false, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6}, vector<int64_t>{6}, Shape{6}));
tests.emplace_back(make_tuple(
true, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6}, vector<int64_t>{6}, Shape{6}));
tests.emplace_back(make_tuple(
false, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6}, vector<int64_t>{-1}, Shape{6}));
tests.emplace_back(make_tuple(false,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{2, -1},
Shape{2, 3}));
tests.emplace_back(make_tuple(false,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{3, -1},
Shape{3, 2}));
tests.emplace_back(make_tuple(false,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{3, 2, -1},
Shape{3, 2, 1}));
tests.emplace_back(make_tuple(true,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{3, 2, -1},
Shape{3, 2, 1}));
tests.emplace_back(make_tuple(true,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{0, 0, -1},
Shape{2, 3, 1}));
tests.emplace_back(make_tuple(true,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{2, 0, -1},
Shape{2, 3, 1}));
tests.emplace_back(make_tuple(
true, Shape{0, 3, 4}, vector<int32_t>{}, vector<int64_t>{3, -1, 2}, Shape{3, 0, 2}));
for (auto& test : tests)
{
bool zero_flag = get<0>(test);
const Shape& in_shape = get<1>(test);
const std::vector<int32_t>& data = get<2>(test);
const std::vector<int64_t>& dims = get<3>(test);
const Shape& out_shape = get<4>(test);
auto t_x = backend->create_tensor(element::i32, in_shape);
auto t_shape = backend->create_tensor(element::i64, Shape{dims.size()});
copy_data(t_x, data);
copy_data(t_shape, dims);
auto ex = zero_flag ? ex_flag_on : ex_flag_off;
ex->call_with_validate({t_r}, {t_x, t_shape});
ASSERT_EQ(t_r->get_element_type(), element::i32);
ASSERT_EQ(t_r->get_shape(), out_shape);
auto results = read_vector<int32_t>(t_r);
ASSERT_EQ(results, data);
}
}
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: ""
output: "Y_h"
op_type: "LSTM"
attribute {
name: "clip"
f: 9999.0
type: FLOAT
}
attribute {
name: "direction"
s: "forward"
type: STRING
}
attribute {
name: "hidden_size"
i: 3
type: INT
}
}
name: "compute_graph"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 32
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 12
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 12
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 32
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 7
}
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <fstream> #include <fstream>
#include <iterator> #include <iterator>
#include <limits> #include <limits>
#include <numeric>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
...@@ -203,3 +204,48 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation) ...@@ -203,3 +204,48 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_hardsigmoid_activation)
test_case.set_tolerance(6); test_case.set_tolerance(6);
test_case.run(); test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_large_batch_no_clip)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_large_batch_no_clip.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
std::size_t seq_length = 2;
std::size_t batch_size = 32;
std::size_t input_size = 1;
std::size_t hidden_size = 3;
std::vector<float> in_X(seq_length * batch_size * input_size);
std::iota(std::begin(in_X), std::end(in_X), 1.f);
std::vector<float> in_R(4 * hidden_size * hidden_size, 0.1f);
// X
test_case.add_input<float>(in_X);
// W
test_case.add_input<float>(
{0.1f, 0.2f, 0.3f, 0.4f, 1.f, 2.f, 3.f, 4.f, 10.f, 11.f, 12.f, 13.f});
// R
test_case.add_input<float>(in_R);
// Y_h_data
test_case.add_expected_output<float>(
Shape{1, batch_size, hidden_size},
{0.90387899f, 0.9135572f, 0.91772245f, 0.90897038f, 0.92132433f, 0.92825467f, 0.91365823f,
0.92815113f, 0.93676105f, 0.91799162f, 0.93406357f, 0.94344562f, 0.92199681f, 0.93912057f,
0.94859476f, 0.92569357f, 0.94340185f, 0.95250664f, 0.92909964f, 0.94699686f, 0.95545127f,
0.93223207f, 0.94999634f, 0.95765468f, 0.93510761f, 0.9524867f, 0.95929726f, 0.93774272f,
0.9545467f, 0.96051891f, 0.9401536f, 0.95624603f, 0.96142619f, 0.94235605f, 0.95764499f,
0.96209939f, 0.94436539f, 0.95879495f, 0.96259862f, 0.94619635f, 0.95973921f, 0.96296872f,
0.94786299f, 0.96051397f, 0.96324302f, 0.94937864f, 0.96114929f, 0.96344629f, 0.95075587f,
0.96167006f, 0.96359692f, 0.95200645f, 0.96209679f, 0.96370852f, 0.95314133f, 0.9624464f,
0.9637912f, 0.95417069f, 0.96273278f, 0.96385246f, 0.95510395f, 0.96296733f, 0.96389785f,
0.95594975f, 0.96315942f, 0.96393147f, 0.95671607f, 0.96331673f, 0.96395638f, 0.9574102f,
0.96344554f, 0.96397483f, 0.9580388f, 0.96355102f, 0.9639885f, 0.95860795f, 0.96363739f,
0.96399863f, 0.95912322f, 0.96370811f, 0.96400613f, 0.95958963f, 0.96376601f, 0.96401169f,
0.96001179f, 0.96381342f, 0.96401581f, 0.96039386f, 0.96385224f, 0.96401886f, 0.96073964f,
0.96388402f, 0.96402112f, 0.96105254f, 0.96391004f, 0.96402279f});
test_case.run();
}
This diff is collapsed.
This diff is collapsed.
#!/bin/bash
# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
declare THIS_SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
python ${THIS_SCRIPT_DIR}/ref_generators/generate_dyn_replace_slice_ref.py ${THIS_SCRIPT_DIR}/dyn_replace_slice_test.in.cpp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment