Commit e5f489a2 authored by Michał Karzyński, committed by Robert Kimball

[Fused] Add DepthToSpace and SpaceToDepth fused ops (#2811)

* Refactor get_default_axis_vector to use std:: functions

* Move get_default_axis_vector to ngraph::op

* Move reorder_axes to ngraph::op::util

* Move reshape helper to ngraph::op::util

* Move DepthToSpace to fused ops

* Add DepthToSpace docstrings

* Move SpaceToDepth to fused ops

* Remove redundant ngraph::op::util::get_default_axis_vector function

* Add ops to serializer

* Change block_size to size_t

* Add fused ops tests

* Add type prop tests

* Add ops to list of ops unsupported on iGPU

* Disable tests in iGPU manifest
parent 8491030d
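For context, a minimal sketch (editorial, not part of the commit itself) of constructing the new fused ops; the parameter shape mirrors the unit tests added below:

    auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{1, 8, 2, 2});
    auto d2s = std::make_shared<ngraph::op::DepthToSpace>(data, 2); // output shape {1, 2, 4, 4}
    auto s2d = std::make_shared<ngraph::op::SpaceToDepth>(d2s, 2);  // output shape {1, 8, 2, 2}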
@@ -259,10 +259,14 @@ set (SRC
     op/topk.hpp
     op/fused/conv_fused.cpp
     op/fused/conv_fused.hpp
+    op/fused/depth_to_space.cpp
+    op/fused/depth_to_space.hpp
     op/fused/elu.cpp
     op/fused/elu.hpp
     op/fused/prelu.cpp
     op/fused/prelu.hpp
+    op/fused/space_to_depth.cpp
+    op/fused/space_to_depth.hpp
     op/util/arithmetic_reduction.cpp
     op/util/arithmetic_reduction.hpp
     op/util/binary_elementwise_arithmetic.cpp
@@ -279,6 +283,8 @@ set (SRC
     op/util/index_reduction.hpp
     op/util/logical_reduction.cpp
     op/util/logical_reduction.hpp
+    op/util/reshape.cpp
+    op/util/reshape.hpp
     op/util/unary_elementwise_arithmetic.cpp
     op/util/unary_elementwise_arithmetic.hpp
     parameter_vector.hpp
@@ -14,15 +14,8 @@
 // limitations under the License.
 //*****************************************************************************
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-
 #include "depth_to_space.hpp"
-#include "exceptions.hpp"
-#include "ngraph/node.hpp"
-#include "ngraph/shape.hpp"
-#include "utils/reshape.hpp"
+#include "ngraph/op/fused/depth_to_space.hpp"

 namespace ngraph
 {
@@ -35,43 +28,8 @@ namespace ngraph
             NodeVector depth_to_space(const Node& node)
             {
                 auto data = node.get_ng_inputs().at(0);
-                const Shape& data_shape = data->get_shape();
-
-                std::int64_t block_size{node.get_attribute_value<std::int64_t>("blocksize")};
-
-                // Set default values to each dimension to be able to work with both 3D or 4D data.
-                std::size_t n{1}, c{1}, h{1}, w{1};
-
-                ASSERT_VALID_ARGUMENT(node, (data_shape.size() == 3 || data_shape.size() == 4))
-                    << "The provided tensor shape: " << data_shape << " is not supported.";
-
-                // Assume NCHW data layout
-                if (data_shape.size() == 4)
-                {
-                    n = data_shape.at(0);
-                    c = data_shape.at(1);
-                    h = data_shape.at(2);
-                    w = data_shape.at(3);
-                }
-                // Without batch.
-                else if (data_shape.size() == 3)
-                {
-                    c = data_shape.at(0);
-                    h = data_shape.at(1);
-                    w = data_shape.at(2);
-                }
-
-                ASSERT_VALID_ARGUMENT(node,
-                                      (c % (block_size * block_size) == 0 && block_size > 0))
-                    << "The depth axis size must be a multiple of squared block_size attribute "
-                       "value";
-
-                std::size_t bs = static_cast<std::size_t>(block_size);
-                std::size_t c_flat = c / (bs * bs);
-
-                // First we have to disperse the data from depth channel, then rearrange them
-                // so as appropriate chunks of data where close to their destination place.
-                // Finally squeeze data from respective dimensions.
-                std::shared_ptr<ngraph::Node> flat_node =
-                    reshape::reshape(data, ngraph::Shape{n, bs, bs, c_flat, h, w});
-                flat_node = reshape::reorder_axes(flat_node, {0, 3, 4, 1, 5, 2});
-                return {reshape::reshape(flat_node, ngraph::Shape{n, c_flat, h * bs, w * bs})};
+                std::size_t block_size = node.get_attribute_value<std::int64_t>("blocksize");
+                return NodeVector{std::make_shared<ngraph::op::DepthToSpace>(data, block_size)};
             }
         } // namespace set_1
@@ -23,6 +23,8 @@
 #include "ngraph/axis_set.hpp"
 #include "ngraph/op/concat.hpp"
 #include "ngraph/op/reshape.hpp"
+#include "ngraph/op/util/reshape.hpp"
+#include "ngraph/util.hpp"
 #include "utils/common.hpp"
 #include "utils/norm.hpp"
 #include "utils/reshape.hpp"
@@ -62,7 +64,7 @@ namespace ngraph
                     output_shape.at(0) = orig_shape.at(0);
                     slice = std::make_shared<ngraph::op::Reshape>(
                         slice,
-                        reshape::get_default_axis_vector(slice->get_shape().size()),
+                        ngraph::get_default_order(slice->get_shape().size()),
                         output_shape);
                 }
@@ -29,7 +29,9 @@
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/slice.hpp"
 #include "ngraph/op/util/broadcasting.hpp"
+#include "ngraph/op/util/reshape.hpp"
 #include "ngraph/shape.hpp"
+#include "ngraph/util.hpp"
 #include "utils/reshape.hpp"

 /// \brief Slice the sub matrix from the input tensor.
@@ -200,7 +202,7 @@ namespace ngraph
                         std::begin(left_shape),
                         std::next(std::begin(left_shape), left_shape.size() - 2));
                     return {std::make_shared<ngraph::op::Reshape>(
-                        result, reshape::get_default_axis_vector(shape.size()), result_shape)};
+                        result, ngraph::get_default_order(shape.size()), result_shape)};
                 }
             }
@@ -22,6 +22,7 @@
 #include "ngraph/axis_vector.hpp"
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/reshape.hpp"
+#include "ngraph/op/util/reshape.hpp"
 #include "ngraph/shape.hpp"
 #include "reshape.hpp"
 #include "utils/reshape.hpp"
@@ -63,9 +64,7 @@ namespace ngraph
                     output_shape =
                         reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
                 return {std::make_shared<ngraph::op::Reshape>(
-                    data,
-                    reshape::get_default_axis_vector(data_shape.size()),
-                    Shape{output_shape})};
+                    data, ngraph::get_default_order(data_shape.size()), Shape{output_shape})};
             }
         } // namespace set_1
@@ -14,15 +14,8 @@
 // limitations under the License.
 //*****************************************************************************
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-
-#include "exceptions.hpp"
-#include "ngraph/node.hpp"
-#include "ngraph/shape.hpp"
+#include "ngraph/op/fused/space_to_depth.hpp"
 #include "space_to_depth.hpp"
-#include "utils/reshape.hpp"

 namespace ngraph
 {
@@ -35,45 +28,8 @@ namespace ngraph
             NodeVector space_to_depth(const Node& node)
            {
                 auto data = node.get_ng_inputs().at(0);
-                const Shape& data_shape = data->get_shape();
-
-                std::int64_t block_size{node.get_attribute_value<std::int64_t>("blocksize")};
-
-                // Set default values to each dimension to be able to work with both 3D or 4D data.
-                std::size_t n{1}, c{1}, h{1}, w{1};
-
-                ASSERT_VALID_ARGUMENT(node, (data_shape.size() == 3 || data_shape.size() == 4))
-                    << "The provided tensor shape: " << data_shape << " is not supported.";
-
-                // Assume NCHW data layout
-                if (data_shape.size() == 4)
-                {
-                    n = data_shape.at(0);
-                    c = data_shape.at(1);
-                    h = data_shape.at(2);
-                    w = data_shape.at(3);
-                }
-                // Without batch.
-                else if (data_shape.size() == 3)
-                {
-                    c = data_shape.at(0);
-                    h = data_shape.at(1);
-                    w = data_shape.at(2);
-                }
-
-                ASSERT_VALID_ARGUMENT(
-                    node, (h % block_size == 0 && w % block_size == 0 && block_size > 0))
-                    << "The width and height axes size must be a multiple of squared block_size"
-                       " attribute value";
-
-                std::size_t bs = static_cast<std::size_t>(block_size);
-                std::size_t w_flat = w / bs;
-                std::size_t h_flat = h / bs;
-                std::size_t c_high = c * bs * bs;
-
-                // First we have to disperse the data from height and width channels, then
-                // rearrange them so as appropriate chunks of data where close to their
-                // destination place. Finally squeeze data from respective dimensions.
-                std::shared_ptr<ngraph::Node> flat_node =
-                    reshape::reshape(data, ngraph::Shape{n, c, h_flat, bs, w_flat, bs});
-                flat_node = reshape::reorder_axes(flat_node, {0, 3, 5, 1, 2, 4});
-                return {reshape::reshape(flat_node, ngraph::Shape{n, c_high, h_flat, w_flat})};
+                std::size_t block_size = node.get_attribute_value<std::int64_t>("blocksize");
+                return NodeVector{std::make_shared<ngraph::op::SpaceToDepth>(data, block_size)};
             }
         } // namespace set_1
@@ -24,7 +24,9 @@
 #include "exceptions.hpp"
 #include "ngraph/axis_vector.hpp"
 #include "ngraph/op/reshape.hpp"
+#include "ngraph/op/util/reshape.hpp"
 #include "ngraph/shape.hpp"
+#include "ngraph/util.hpp"
 #include "squeeze.hpp"
 #include "utils/reshape.hpp"
@@ -42,7 +44,7 @@ namespace ngraph
                 auto data = inputs.at(0);
                 auto data_shape = data->get_shape();
                 auto axes = node.get_attribute_value<std::vector<std::size_t>>("axes", {});
-                AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
+                AxisVector input_order{ngraph::get_default_order(data_shape.size())};

                 // Prepare set of unique axes marked to be removed from input data.
                 if (axes.empty())
@@ -18,6 +18,7 @@
 #include <vector>

 #include "ngraph/node.hpp"
+#include "ngraph/op/util/reshape.hpp"
 #include "transpose.hpp"
 #include "utils/reshape.hpp"
@@ -36,8 +37,9 @@ namespace ngraph
                 auto permute_axes =
                     node.get_attribute_value<std::vector<std::size_t>>("perm", {});

-                return {(permute_axes.empty()) ? reshape::transpose(data)
-                                               : reshape::reorder_axes(data, permute_axes)};
+                return {(permute_axes.empty())
+                            ? reshape::transpose(data)
+                            : ngraph::op::util::reorder_axes(data, permute_axes)};
             }
         } // namespace set_1
@@ -20,6 +20,8 @@
 #include "exceptions.hpp"
 #include "ngraph/op/reshape.hpp"
+#include "ngraph/op/util/reshape.hpp"
+#include "ngraph/util.hpp"
 #include "unsqueeze.hpp"
 #include "utils/reshape.hpp"
@@ -47,7 +49,7 @@ namespace ngraph
                 std::sort(std::begin(axes), std::end(axes), std::less<int64_t>());
-                AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
+                AxisVector input_order{ngraph::get_default_order(data_shape.size())};

                 for (auto axis : axes)
                 {
@@ -18,6 +18,7 @@
 #include <vector>

 #include "exceptions.hpp"
+#include "ngraph/op/util/reshape.hpp"
 #include "reduction.hpp"
 #include "utils/common.hpp"
@@ -72,7 +73,7 @@ namespace ngraph
                 }
                 return std::make_shared<ngraph::op::Reshape>(
                     op_node,
-                    reshape::get_default_axis_vector(op_node->get_shape().size()),
+                    ngraph::get_default_order(op_node->get_shape().size()),
                     Shape{output_shape});
             }
@@ -23,7 +23,9 @@
 #include "ngraph/axis_set.hpp"
 #include "ngraph/op/convert.hpp"
 #include "ngraph/op/reshape.hpp"
+#include "ngraph/op/util/reshape.hpp"
 #include "ngraph/shape.hpp"
+#include "ngraph/util.hpp"
 #include "utils/reshape.hpp"

 namespace ngraph
@@ -77,7 +79,7 @@ namespace ngraph
                 output_shape.at(axis) = 1;
                 auto reshape_node = std::make_shared<ngraph::op::Reshape>(
                     convert_node,
-                    reshape::get_default_axis_vector(op_node->get_shape().size()),
+                    ngraph::get_default_order(op_node->get_shape().size()),
                     Shape{output_shape});

                 // WORKAROUND FOR PROBLEMS WITH RESHAPE ON i64 @TODO: remove
@@ -25,6 +25,7 @@
 #include "exceptions.hpp"
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/slice.hpp"
+#include "ngraph/op/util/reshape.hpp"
 #include "utils/common.hpp"
 #include "utils/reshape.hpp"
@@ -80,16 +81,10 @@ namespace ngraph
                 return std::make_shared<ngraph::op::Reshape>(
                     node,
-                    get_default_axis_vector(data_shape.size()),
+                    ngraph::get_default_order(data_shape.size()),
                     Shape{first_dim_size, last_dim_size});
             }

-            AxisVector get_default_axis_vector(std::size_t data_shape_size, std::size_t start_value)
-            {
-                return AxisVector{
-                    common::get_monotonic_range<std::size_t>(data_shape_size, start_value)};
-            }
-
             std::vector<std::size_t> infer_dimensions(const std::string& node_name,
                                                       const std::vector<std::size_t>& input_shape,
                                                       const std::vector<std::size_t>& output_shape)
@@ -145,33 +140,12 @@ namespace ngraph
                 return inferred_dims;
             }

-            std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
-                                                       std::vector<std::size_t> axes_order = {})
-            {
-                Shape out_shape = node->get_shape();
-                if (axes_order.empty())
-                {
-                    axes_order.resize(out_shape.size());
-                    std::iota(std::begin(axes_order), std::end(axes_order), 0);
-                }
-                else
-                {
-                    for (std::size_t i = 0; i < axes_order.size(); ++i)
-                    {
-                        out_shape[i] = node->get_shape().at(axes_order.at(i));
-                    }
-                }
-
-                auto axis_vector = AxisVector{std::begin(axes_order), std::end(axes_order)};
-                return std::make_shared<ngraph::op::Reshape>(node, axis_vector, out_shape);
-            }
-
             std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node)
             {
                 std::vector<size_t> axes_order(node->get_shape().size());
                 std::iota(std::begin(axes_order), std::end(axes_order), 0);
                 std::reverse(std::begin(axes_order), std::end(axes_order));
-                return reorder_axes(node, axes_order);
+                return ngraph::op::util::reorder_axes(node, axes_order);
             }

             std::shared_ptr<ngraph::Node> squeeze(const std::shared_ptr<ngraph::Node>& node,
@@ -195,7 +169,7 @@ namespace ngraph
                         output_shape.push_back(axis);
                     }
                 }
-                return reshape(node, output_shape);
+                return ngraph::op::util::reshape(node, output_shape);
             }

             std::shared_ptr<ngraph::Node> collapse(const std::shared_ptr<ngraph::Node>& node,
@@ -213,15 +187,7 @@ namespace ngraph
                 output_shape.insert(std::end(output_shape),
                                     std::next(std::begin(shape), end_axis + 1),
                                     std::end(shape));
-                return reshape(node, output_shape);
+                return ngraph::op::util::reshape(node, output_shape);
             }

-            std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
-                                                  const AxisVector& axis_order,
-                                                  const Shape& shape)
-            {
-                return std::make_shared<ngraph::op::Reshape>(
-                    node, get_default_axis_vector(node->get_shape().size()), shape);
-            }
-
             std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
@@ -233,7 +199,7 @@
                 std::advance(empty_axis_it, axis);
                 output_shape.insert(empty_axis_it, 1);
                 return std::make_shared<ngraph::op::Reshape>(
-                    node, reshape::get_default_axis_vector(node->get_shape().size()), output_shape);
+                    node, ngraph::get_default_order(node->get_shape().size()), output_shape);
             }

             NodeVector split(const std::shared_ptr<ngraph::Node>& node,
@@ -24,6 +24,7 @@
 #include "ngraph/axis_vector.hpp"
 #include "ngraph/node.hpp"
 #include "ngraph/node_vector.hpp"
+#include "ngraph/op/util/reshape.hpp"
 #include "ngraph/shape.hpp"

 namespace ngraph
@@ -41,17 +42,6 @@ namespace ngraph
             std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
                                                   int axis);

-            /// \brief Gets the AxisVector filled with monotonic increasing sequence.
-            ///
-            /// \param[in] data_shape_size The data shape size.
-            /// \param[in] start_value The start_value for sequence. Default equals 0.
-            ///
-            /// \return The filled AxisVector.
-            ///
-            AxisVector get_default_axis_vector(std::size_t data_shape_size,
-                                               std::size_t start_value = 0);
-
             /// \brief Infer `output_shape` dimension values.
             ///
             /// \par Inference rules
@@ -70,15 +60,6 @@
                                                       const std::vector<std::size_t>& input_shape,
                                                       const std::vector<std::size_t>& output_shape);

-            /// \brief Permute axes according to specified axes_order parameter.
-            ///
-            /// \param node The node which axes we want to permute.
-            /// \param axes_order The permutation of node tensor axes.
-            ///
-            /// \return New node with permuted axes.
-            std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
-                                                       std::vector<std::size_t> axes_order);
-
             /// \brief Return transposed tensor (with axes in reversed order).
             ///
             /// \param node Input tensor we want to transpose
@@ -110,23 +91,6 @@
                                  const std::size_t start_axis,
                                  const std::size_t end_axis);

-            /// \brief Change shape of input tensor.
-            ///
-            /// \param[in] node The node which shape will be changed.
-            /// \param[in] shape The new shape for input tensor.
-            ///
-            /// \return The node representing reshaped input tensor.
-            ///
-            std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
-                                                  const AxisVector& axis_order,
-                                                  const Shape& shape);
-
-            inline std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
-                                                         const Shape& shape)
-            {
-                return reshape(node, get_default_axis_vector(node->get_shape().size()), shape);
-            }
-
             /// \brief Expands node tensor shape with empty axis at
             ///        specified position.
             ///
@@ -95,8 +95,10 @@
 #include "ngraph/op/experimental/transpose.hpp"
 #include "ngraph/op/floor.hpp"
 #include "ngraph/op/fused/conv_fused.hpp"
+#include "ngraph/op/fused/depth_to_space.hpp"
 #include "ngraph/op/fused/elu.hpp"
 #include "ngraph/op/fused/prelu.hpp"
+#include "ngraph/op/fused/space_to_depth.hpp"
 #include "ngraph/op/gather.hpp"
 #include "ngraph/op/gather_nd.hpp"
 #include "ngraph/op/get_output_element.hpp"
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <cstdint>
#include <memory>
#include "depth_to_space.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
op::DepthToSpace::DepthToSpace(const shared_ptr<Node>& data, const size_t block_size)
: FusedOp("DepthToSpace", {data})
, m_blocksize(block_size)
{
constructor_validate_and_infer_types();
}
NodeVector op::DepthToSpace::decompose_op() const
{
auto data = get_argument(0);
const Shape& data_shape = data->get_shape();
// Set default values to each dimension to be able to work with both 3D or 4D data.
size_t n{1}, c{1}, h{1}, w{1};
NGRAPH_CHECK((data_shape.size() == 3 || data_shape.size() == 4),
"The provided tensor shape: ",
data_shape,
" is not supported.");
// Assume NCHW data layout
if (data_shape.size() == 4)
{
n = data_shape.at(0);
c = data_shape.at(1);
h = data_shape.at(2);
w = data_shape.at(3);
}
// Without batch.
else if (data_shape.size() == 3)
{
c = data_shape.at(0);
h = data_shape.at(1);
w = data_shape.at(2);
}
NGRAPH_CHECK((c % (m_blocksize * m_blocksize) == 0 && m_blocksize > 0),
"DepthToSpace: The depth axis size must be a multiple of ",
"squared block_size attribute value.");
auto bs = static_cast<size_t>(m_blocksize);
size_t c_flat = c / (bs * bs);
// First we have to disperse the data from the depth channel, then rearrange it
// so that the appropriate chunks of data are close to their destination place.
// Finally we squeeze the data from the respective dimensions.
shared_ptr<Node> flat_node = op::util::reshape(data, Shape{n, bs, bs, c_flat, h, w});
flat_node = op::util::reorder_axes(flat_node, {0, 3, 4, 1, 5, 2});
return NodeVector{op::util::reshape(flat_node, Shape{n, c_flat, h * bs, w * bs})};
}
shared_ptr<Node> op::DepthToSpace::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<DepthToSpace>(new_args.at(0), m_blocksize);
}
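As a worked illustration of the decomposition above (shapes borrowed from the unit test added later in this commit; the trace itself is editorial, not part of the source):

    // Shape trace of DepthToSpace::decompose_op for data_shape {1, 8, 2, 2}, m_blocksize = 2:
    // reshape to {n, bs, bs, c_flat, h, w}:   {1, 8, 2, 2} -> {1, 2, 2, 2, 2, 2}
    // reorder_axes with {0, 3, 4, 1, 5, 2}:   {1, 2, 2, 2, 2, 2} (axes permuted)
    // reshape to {n, c_flat, h * bs, w * bs}: -> {1, 2, 4, 4}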
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief DepthToSpace permutes data from the depth dimension of the input blob into spatial dimensions.
///
/// \note Values from the depth dimension (assuming NCHW layout) are moved in
/// spatial blocks to the height and width dimensions.
///
/// Output node produces a tensor with shape:
/// [N, C/(blocksize * blocksize), H * blocksize, W * blocksize]
class DepthToSpace : public ngraph::op::util::FusedOp
{
public:
/// \brief Constructs a DepthToSpace operation.
///
/// \param data - Node producing the input tensor
/// \param block_size - the size of the block of values to be moved
DepthToSpace(const std::shared_ptr<ngraph::Node>& data, std::size_t block_size);
std::size_t get_block_size() const { return m_blocksize; }
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
std::size_t m_blocksize;
};
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <cstdint>
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp"
#include "space_to_depth.hpp"
using namespace std;
using namespace ngraph;
op::SpaceToDepth::SpaceToDepth(const shared_ptr<Node>& data, const size_t block_size)
: FusedOp("SpaceToDepth", {data})
, m_blocksize(block_size)
{
constructor_validate_and_infer_types();
}
NodeVector op::SpaceToDepth::decompose_op() const
{
auto data = get_argument(0);
const Shape& data_shape = data->get_shape();
// Set default values to each dimension to be able to work with both 3D or 4D data.
size_t n{1}, c{1}, h{1}, w{1};
NGRAPH_CHECK((data_shape.size() == 3 || data_shape.size() == 4),
"The provided tensor shape: ",
data_shape,
" is not supported.");
// Assume NCHW data layout
if (data_shape.size() == 4)
{
n = data_shape.at(0);
c = data_shape.at(1);
h = data_shape.at(2);
w = data_shape.at(3);
}
// Without batch.
else if (data_shape.size() == 3)
{
c = data_shape.at(0);
h = data_shape.at(1);
w = data_shape.at(2);
}
NGRAPH_CHECK((h % m_blocksize == 0 && w % m_blocksize == 0 && m_blocksize > 0),
"SpaceToDepth: The width and height axes size must be a multiple of ",
"the block_size attribute value.");
size_t bs = static_cast<size_t>(m_blocksize);
size_t w_flat = w / bs;
size_t h_flat = h / bs;
size_t c_high = c * bs * bs;
// First we have to disperse the data from the height and width channels, then
// rearrange it so that the appropriate chunks of data are close to their
// destination place. Finally we squeeze the data from the respective dimensions.
shared_ptr<Node> flat_node = op::util::reshape(data, Shape{n, c, h_flat, bs, w_flat, bs});
flat_node = op::util::reorder_axes(flat_node, {0, 3, 5, 1, 2, 4});
return NodeVector{op::util::reshape(flat_node, Shape{n, c_high, h_flat, w_flat})};
}
shared_ptr<Node> op::SpaceToDepth::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<SpaceToDepth>(new_args.at(0), m_blocksize);
}
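The same kind of editorial trace for SpaceToDepth, using the shapes from the corresponding unit test:

    // Shape trace of SpaceToDepth::decompose_op for data_shape {1, 2, 4, 4}, m_blocksize = 2:
    // reshape to {n, c, h_flat, bs, w_flat, bs}: {1, 2, 4, 4} -> {1, 2, 2, 2, 2, 2}
    // reorder_axes with {0, 3, 5, 1, 2, 4}:      {1, 2, 2, 2, 2, 2} (axes permuted)
    // reshape to {n, c_high, h_flat, w_flat}:    -> {1, 8, 2, 2}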
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief SpaceToDepth permutes input tensor blocks of spatial data into depth dimension.
///
/// \note Values from the height and width dimensions are moved to the depth dimension.
///
/// Output node produces a tensor with shape:
/// [N, C * blocksize * blocksize, H / blocksize, W / blocksize]
class SpaceToDepth : public ngraph::op::util::FusedOp
{
public:
/// \brief Constructs a SpaceToDepth operation.
///
/// \param data - Node producing the input tensor
/// \param block_size - the size of the block of values to be moved
SpaceToDepth(const std::shared_ptr<ngraph::Node>& data, std::size_t block_size);
std::size_t get_block_size() const { return m_blocksize; }
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
std::size_t m_blocksize;
};
}
}
@@ -22,3 +22,5 @@ NGRAPH_OP(PRelu, ngraph::op)
 NGRAPH_OP(ConvolutionBias, ngraph::op)
 NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
 NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
+NGRAPH_OP(DepthToSpace, ngraph::op)
+NGRAPH_OP(SpaceToDepth, ngraph::op)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include <util.hpp>
#include "ngraph/node.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
#include "reshape.hpp"
using namespace ngraph;
std::shared_ptr<Node> op::util::reshape(const std::shared_ptr<Node>& node,
const AxisVector& axis_order,
const Shape& shape)
{
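    // Note: the axis_order parameter is ignored here; the input is read in its
    // default (row-major) order before being reshaped to `shape`.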
return std::make_shared<op::Reshape>(
node, ngraph::get_default_order(node->get_shape().size()), shape);
}
std::shared_ptr<Node> op::util::reorder_axes(const std::shared_ptr<Node>& node,
std::vector<std::size_t> axes_order = {})
{
Shape out_shape = node->get_shape();
if (axes_order.empty())
{
axes_order.resize(out_shape.size());
std::iota(std::begin(axes_order), std::end(axes_order), 0);
}
else
{
for (std::size_t i = 0; i < axes_order.size(); ++i)
{
out_shape[i] = node->get_shape().at(axes_order.at(i));
}
}
auto axis_vector = AxisVector{std::begin(axes_order), std::end(axes_order)};
return std::make_shared<op::Reshape>(node, axis_vector, out_shape);
}
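A brief usage sketch of the relocated helpers (editorial; the Parameter and shapes are illustrative, and the unqualified names follow the using-declarations in the file above):

    // Permute a {2, 3, 4} tensor to {4, 2, 3}, then flatten its trailing axes.
    auto p = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    auto permuted = op::util::reorder_axes(p, {2, 0, 1}); // output shape {4, 2, 3}
    auto flat = op::util::reshape(permuted, Shape{4, 6}); // default axis order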
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Change shape of input tensor.
///
/// \param[in] node The node which shape will be used as input to Reshape.
/// \param[in] shape The new shape for input tensor.
///
/// \return The node representing a Reshape operation.
///
std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
const AxisVector& axis_order,
const Shape& shape);
inline std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
const Shape& shape)
{
return reshape(node, ngraph::get_default_order(node->get_shape().size()), shape);
}
/// \brief Permute axes according to specified axes_order parameter.
///
/// \param node The node which axes we want to permute.
/// \param axes_order The permutation of node tensor axes.
///
/// \return: New node with permuted axes.
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes_order);
} // namespace util
} // namespace op
} // namespace ngraph
@@ -75,7 +75,9 @@
 #include "ngraph/op/equal.hpp"
 #include "ngraph/op/erf.hpp"
 #include "ngraph/op/fused/conv_fused.hpp"
+#include "ngraph/op/fused/depth_to_space.hpp"
 #include "ngraph/op/fused/elu.hpp"
+#include "ngraph/op/fused/space_to_depth.hpp"
 #include "ngraph/op/get_output_element.hpp"
 #include "ngraph/op/greater.hpp"
 #include "ngraph/op/greater_eq.hpp"
@@ -1912,6 +1914,7 @@ shared_ptr<runtime::Executable>
     case OP_TYPEID::BatchMatMul:
     case OP_TYPEID::BroadcastDistributed:
     case OP_TYPEID::BroadcastLike:
+    case OP_TYPEID::DepthToSpace:
     case OP_TYPEID::DynBroadcast:
     case OP_TYPEID::DynPad:
     case OP_TYPEID::DynReshape:
@@ -1936,6 +1939,7 @@ shared_ptr<runtime::Executable>
     case OP_TYPEID::ReplaceSlice:
     case OP_TYPEID::ScalarConstantLike:
     case OP_TYPEID::ShapeOf:
+    case OP_TYPEID::SpaceToDepth:
     case OP_TYPEID::StopGradient:
     case OP_TYPEID::Transpose:
     default:
@@ -66,8 +66,10 @@
 #include "ngraph/op/experimental/transpose.hpp"
 #include "ngraph/op/floor.hpp"
 #include "ngraph/op/fused/conv_fused.hpp"
+#include "ngraph/op/fused/depth_to_space.hpp"
 #include "ngraph/op/fused/elu.hpp"
 #include "ngraph/op/fused/prelu.hpp"
+#include "ngraph/op/fused/space_to_depth.hpp"
 #include "ngraph/op/gather.hpp"
 #include "ngraph/op/gather_nd.hpp"
 #include "ngraph/op/get_output_element.hpp"
@@ -843,6 +845,12 @@ static shared_ptr<ngraph::Function>
             node = make_shared<op::Cosh>(args[0]);
             break;
         }
+        case OP_TYPEID::DepthToSpace:
+        {
+            auto block_size = node_js.at("block_size").get<size_t>();
+            node = make_shared<op::DepthToSpace>(args[0], block_size);
+            break;
+        }
         case OP_TYPEID::Dequantize:
         {
             auto type = read_element_type(node_js.at("type"));
@@ -1339,6 +1347,12 @@ static shared_ptr<ngraph::Function>
             node = make_shared<op::Softmax>(args[0], softmax_axes);
             break;
         }
+        case OP_TYPEID::SpaceToDepth:
+        {
+            auto block_size = node_js.at("block_size").get<size_t>();
+            node = make_shared<op::SpaceToDepth>(args[0], block_size);
+            break;
+        }
         case OP_TYPEID::Sqrt:
         {
             node = make_shared<op::Sqrt>(args[0]);
@@ -1712,6 +1726,13 @@ static json write(const Node& n, bool binary_constant_data)
         node["axes"] = tmp->get_axes();
         break;
     }
+    case OP_TYPEID::DepthToSpace:
+    {
+        auto tmp = dynamic_cast<const op::DepthToSpace*>(&n);
+        node["type"] = write_element_type(tmp->get_element_type());
+        node["block_size"] = tmp->get_block_size();
+        break;
+    }
     case OP_TYPEID::Divide: { break;
     }
     case OP_TYPEID::Dot:
@@ -1990,6 +2011,13 @@ static json write(const Node& n, bool binary_constant_data)
         node["strides"] = tmp->get_strides();
         break;
     }
+    case OP_TYPEID::SpaceToDepth:
+    {
+        auto tmp = dynamic_cast<const op::SpaceToDepth*>(&n);
+        node["type"] = write_element_type(tmp->get_element_type());
+        node["block_size"] = tmp->get_block_size();
+        break;
+    }
     case OP_TYPEID::Sqrt: { break;
     }
     case OP_TYPEID::StopGradient: { break;
@@ -271,3 +271,41 @@ NGRAPH_TEST(${BACKEND_NAME}, conv_bias_add_2d)
     vector<float> expected{40, 47, 54, 61, 90, 106, 122, 138};
     EXPECT_EQ(expected, read_vector<float>(result0));
 }
NGRAPH_TEST(${BACKEND_NAME}, space_to_depth)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4, 4});
auto space_to_depth = make_shared<op::SpaceToDepth>(A, 2);
auto function = make_shared<Function>(NodeVector{space_to_depth}, ParameterVector{A});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>({0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,
11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f,
22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f});
test_case.add_expected_output<float>(Shape{1, 8, 2, 2},
{
0.f, 2.f, 8.f, 10.f, 16.f, 18.f, 24.f, 26.f,
1.f, 3.f, 9.f, 11.f, 17.f, 19.f, 25.f, 27.f,
4.f, 6.f, 12.f, 14.f, 20.f, 22.f, 28.f, 30.f,
5.f, 7.f, 13.f, 15.f, 21.f, 23.f, 29.f, 31.f,
});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, depth_to_space)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 8, 2, 2});
auto depth_to_space = make_shared<op::DepthToSpace>(A, 2);
auto function = make_shared<Function>(NodeVector{depth_to_space}, ParameterVector{A});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>({
0.f, 2.f, 8.f, 10.f, 16.f, 18.f, 24.f, 26.f, 1.f, 3.f, 9.f, 11.f, 17.f, 19.f, 25.f, 27.f,
4.f, 6.f, 12.f, 14.f, 20.f, 22.f, 28.f, 30.f, 5.f, 7.f, 13.f, 15.f, 21.f, 23.f, 29.f, 31.f,
});
test_case.add_expected_output<float>(
Shape{1, 2, 4, 4}, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,
11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f,
22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f});
test_case.run();
}
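Note that the two backend tests are mutual inverses: the expected output of space_to_depth is exactly the input of depth_to_space, and vice versa. A hedged sketch of that round-trip property for matching block sizes (an observation, not a test added by this commit):

    auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4, 4});
    auto round_trip =
        make_shared<op::DepthToSpace>(make_shared<op::SpaceToDepth>(A, 2), 2);
    // round_trip has shape {1, 2, 4, 4} and reproduces A element-wise.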
@@ -13405,6 +13405,24 @@ TEST(type_prop, gather)
     ASSERT_EQ(G->get_shape(), out_shape);
 }
TEST(type_prop, depth_to_space)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 128, 8, 8});
auto depth_to_space = make_shared<op::DepthToSpace>(A, 8);
ASSERT_EQ(depth_to_space->get_element_type(), element::f32);
ASSERT_EQ(depth_to_space->get_shape(), (Shape{1, 2, 64, 64}));
}
TEST(type_prop, space_to_depth)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
auto space_to_depth = make_shared<op::SpaceToDepth>(A, 8);
ASSERT_EQ(space_to_depth->get_element_type(), element::f32);
ASSERT_EQ(space_to_depth->get_shape(), (Shape{1, 128, 8, 8}));
}
TEST(type_prop, gather_nd_scalar_from_2d)
{
    Shape params_shape{2, 2};