Commit db34286c authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

Move reshape functions from utils to builder. (#2984)

* Move reshape from utils to builder.

* Add aliases to functions in old place and describe changes.
parent c06bf6e1
...@@ -68,3 +68,14 @@ make_shared<Function>(results, parameters); ...@@ -68,3 +68,14 @@ make_shared<Function>(results, parameters);
The runtime::Tensor methods to get_tensor<> and write<T>(std::vector&) have been removed The runtime::Tensor methods to get_tensor<> and write<T>(std::vector&) have been removed
to the unit test directory under utils/test_tool.hpp read_vector and write_vector. to the unit test directory under utils/test_tool.hpp read_vector and write_vector.
## Changes to reshape op utils
Utility functions from `src/ngraph/op/util/reshape.hpp`, placed at namespace `ngraph::op::util`:
- `reshape`
- `reorder_axes`
- `transpose`
- `flatten`
Are moved to new location: `src/ngraph/builder/reshape.hpp` to namespace `ngraph::builder`.
...@@ -38,6 +38,8 @@ set (SRC ...@@ -38,6 +38,8 @@ set (SRC
builder/quantization_util.hpp builder/quantization_util.hpp
builder/reduce_ops.cpp builder/reduce_ops.cpp
builder/reduce_ops.hpp builder/reduce_ops.hpp
builder/reshape.cpp
builder/reshape.hpp
builder/split.cpp builder/split.cpp
builder/split.hpp builder/split.hpp
builder/tensor_mask.hpp builder/tensor_mask.hpp
...@@ -342,7 +344,6 @@ set (SRC ...@@ -342,7 +344,6 @@ set (SRC
op/util/index_reduction.hpp op/util/index_reduction.hpp
op/util/logical_reduction.cpp op/util/logical_reduction.cpp
op/util/logical_reduction.hpp op/util/logical_reduction.hpp
op/util/reshape.cpp
op/util/reshape.hpp op/util/reshape.hpp
op/util/unary_elementwise_arithmetic.cpp op/util/unary_elementwise_arithmetic.cpp
op/util/unary_elementwise_arithmetic.hpp op/util/unary_elementwise_arithmetic.hpp
......
...@@ -14,26 +14,25 @@ ...@@ -14,26 +14,25 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric> #include <numeric>
#include <util.hpp>
#include "ngraph/node.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "reshape.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
shared_ptr<Node> op::util::reshape(const shared_ptr<Node>& node, const Shape& shape) shared_ptr<Node> builder::reshape(const shared_ptr<Node>& node, const Shape& shape)
{ {
return make_shared<op::Reshape>(node, get_default_order(node->get_shape().size()), shape); return make_shared<op::Reshape>(node, get_default_order(node->get_shape().size()), shape);
} }
shared_ptr<Node> op::util::reorder_axes(const shared_ptr<Node>& node, shared_ptr<Node> builder::reorder_axes(const shared_ptr<Node>& node, vector<size_t> axes_order = {})
vector<size_t> axes_order = {})
{ {
Shape out_shape = node->get_shape(); Shape out_shape = node->get_shape();
if (axes_order.empty()) if (axes_order.empty())
...@@ -53,15 +52,15 @@ shared_ptr<Node> op::util::reorder_axes(const shared_ptr<Node>& node, ...@@ -53,15 +52,15 @@ shared_ptr<Node> op::util::reorder_axes(const shared_ptr<Node>& node,
return make_shared<op::Reshape>(node, axis_vector, out_shape); return make_shared<op::Reshape>(node, axis_vector, out_shape);
} }
shared_ptr<Node> op::util::transpose(const shared_ptr<Node>& node) shared_ptr<Node> builder::transpose(const shared_ptr<Node>& node)
{ {
vector<size_t> axes_order(node->get_shape().size()); vector<size_t> axes_order(node->get_shape().size());
iota(begin(axes_order), end(axes_order), 0); iota(begin(axes_order), end(axes_order), 0);
reverse(begin(axes_order), end(axes_order)); reverse(begin(axes_order), end(axes_order));
return op::util::reorder_axes(node, axes_order); return builder::reorder_axes(node, axes_order);
} }
shared_ptr<Node> op::util::flatten(const shared_ptr<Node>& node, int axis) shared_ptr<Node> builder::flatten(const shared_ptr<Node>& node, int axis)
{ {
auto data_shape = node->get_shape(); auto data_shape = node->get_shape();
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace builder
{
/// \brief Change shape of input tensor.
///
/// \param[in] node The node which shape will be used as input to Reshape.
/// \param[in] shape The new shape for input tensor.
///
/// \return The node representing a Reshape operation.
///
std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& node, const Shape& shape);
/// \brief Permute axes according to specified axes_order parameter.
///
/// \param node The node which axes we want to permute.
/// \param axes_order The permutation of node tensor axes.
///
/// \return: New node with permuted axes.
std::shared_ptr<Node> reorder_axes(const std::shared_ptr<Node>& node,
std::vector<std::size_t> axes_order);
/// \brief Return transposed tensor (with axes in reversed order).
///
/// \param node Input tensor we want to transpose
///
/// \return: New node with reversed dimensions.
std::shared_ptr<Node> transpose(const std::shared_ptr<Node>& node);
/// \brief Flatten the input tensor into a 2D matrix.
///
/// \param node The tensor to be flattened.
/// \param axis The axis dividing shape.
///
/// \return The new node being a 2D matrix representing flattened input node.
std::shared_ptr<Node> flatten(const std::shared_ptr<Node>& node, int axis);
} // namespace builder
} // namespace ngraph
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "flatten.hpp" #include <cinttypes>
#include "exceptions.hpp" #include "exceptions.hpp"
#include "utils/reshape.hpp" #include "flatten.hpp"
#include "ngraph/builder/reshape.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -30,12 +32,12 @@ namespace ngraph ...@@ -30,12 +32,12 @@ namespace ngraph
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0); auto data = inputs.at(0);
auto axis = node.get_attribute_value<int64_t>("axis", 1); auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size())) ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
<< "provided 'axis' attribute is not valid."; << "provided 'axis' attribute is not valid.";
return {ngraph::op::util::flatten(data, axis)}; return {ngraph::builder::flatten(data, axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
#include "hardmax.hpp" #include "hardmax.hpp"
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/frontend/onnx_import/utils/common.hpp" #include "ngraph/frontend/onnx_import/utils/common.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/embedding_lookup.hpp" #include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/util/reshape.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -40,7 +40,7 @@ namespace ngraph ...@@ -40,7 +40,7 @@ namespace ngraph
<< " does not match the input tensor dimensions"; << " does not match the input tensor dimensions";
// reshape to 2D - "batch size" x "input feature dimensions" (NxD) // reshape to 2D - "batch size" x "input feature dimensions" (NxD)
const auto coerced_tensor = ngraph::op::util::flatten(input, axis); const auto coerced_tensor = ngraph::builder::flatten(input, axis);
const auto& coerced_shape = coerced_tensor->get_shape(); const auto& coerced_shape = coerced_tensor->get_shape();
const std::shared_ptr<ngraph::Node> argmax_2d = const std::shared_ptr<ngraph::Node> argmax_2d =
...@@ -54,7 +54,7 @@ namespace ngraph ...@@ -54,7 +54,7 @@ namespace ngraph
auto results = auto results =
std::make_shared<ngraph::op::EmbeddingLookup>(argmax_2d, eye_matrix); std::make_shared<ngraph::op::EmbeddingLookup>(argmax_2d, eye_matrix);
return {ngraph::op::util::reshape(results, input_shape)}; return {ngraph::builder::reshape(results, input_shape)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "lstm.hpp" #include "lstm.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/builder/make_constant.hpp" #include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
...@@ -397,10 +398,10 @@ namespace ngraph ...@@ -397,10 +398,10 @@ namespace ngraph
// Xt*(W^T) -- for [iofc] gates. // Xt*(W^T) -- for [iofc] gates.
auto Xt_W = std::make_shared<ngraph::op::Dot>( auto Xt_W = std::make_shared<ngraph::op::Dot>(
in_x, ngraph::op::util::transpose(m_W)); in_x, ngraph::builder::transpose(m_W));
// Ht-1*(R^T) -- for [iofc] gates. // Ht-1*(R^T) -- for [iofc] gates.
auto Ht_R = std::make_shared<ngraph::op::Dot>( auto Ht_R = std::make_shared<ngraph::op::Dot>(
H_t, ngraph::op::util::transpose(m_R)); H_t, ngraph::builder::transpose(m_R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates. // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto gates = add(Xt_W, add(Ht_R, bias)); auto gates = add(Xt_W, add(Ht_R, bias));
......
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "ngraph/axis_vector.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "reshape.hpp" #include "reshape.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/builder/reshape.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "transpose.hpp" #include "transpose.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
...@@ -38,8 +38,8 @@ namespace ngraph ...@@ -38,8 +38,8 @@ namespace ngraph
node.get_attribute_value<std::vector<std::size_t>>("perm", {}); node.get_attribute_value<std::vector<std::size_t>>("perm", {});
return {(permute_axes.empty()) return {(permute_axes.empty())
? ngraph::op::util::transpose(data) ? ngraph::builder::transpose(data)
: ngraph::op::util::reorder_axes(data, permute_axes)}; : ngraph::builder::reorder_axes(data, permute_axes)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include <vector> #include <vector>
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "reduction.hpp" #include "reduction.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/op/convert.hpp" #include "ngraph/op/convert.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
......
...@@ -23,9 +23,9 @@ ...@@ -23,9 +23,9 @@
#include <vector> #include <vector>
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
...@@ -111,7 +111,7 @@ namespace ngraph ...@@ -111,7 +111,7 @@ namespace ngraph
output_shape.push_back(axis); output_shape.push_back(axis);
} }
} }
return ngraph::op::util::reshape(node, output_shape); return ngraph::builder::reshape(node, output_shape);
} }
std::shared_ptr<ngraph::Node> collapse(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> collapse(const std::shared_ptr<ngraph::Node>& node,
...@@ -129,7 +129,7 @@ namespace ngraph ...@@ -129,7 +129,7 @@ namespace ngraph
output_shape.insert(std::end(output_shape), output_shape.insert(std::end(output_shape),
std::next(std::begin(shape), end_axis + 1), std::next(std::begin(shape), end_axis + 1),
std::end(shape)); std::end(shape));
return ngraph::op::util::reshape(node, output_shape); return ngraph::builder::reshape(node, output_shape);
} }
std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
...@@ -159,7 +159,7 @@ namespace ngraph ...@@ -159,7 +159,7 @@ namespace ngraph
"Scalar value can't be derived from a node with ", "Scalar value can't be derived from a node with ",
node_shape); node_shape);
return ngraph::op::util::reshape(node, Shape{}); return ngraph::builder::reshape(node, Shape{});
} }
} // namespace reshape } // namespace reshape
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include "ngraph/axis_vector.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
namespace ngraph namespace ngraph
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include <memory> #include <memory>
#include "depth_to_space.hpp" #include "depth_to_space.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
using namespace std; using namespace std;
...@@ -71,9 +71,9 @@ NodeVector op::DepthToSpace::decompose_op() const ...@@ -71,9 +71,9 @@ NodeVector op::DepthToSpace::decompose_op() const
// First we have to disperse the data from depth channel, then rearrange them // First we have to disperse the data from depth channel, then rearrange them
// so as appropriate chunks of data where close to their destination place. // so as appropriate chunks of data where close to their destination place.
// Finally squeeze data from respective dimensions. // Finally squeeze data from respective dimensions.
shared_ptr<Node> flat_node = op::util::reshape(data, Shape{n, bs, bs, c_flat, h, w}); shared_ptr<Node> flat_node = builder::reshape(data, Shape{n, bs, bs, c_flat, h, w});
flat_node = op::util::reorder_axes(flat_node, {0, 3, 4, 1, 5, 2}); flat_node = builder::reorder_axes(flat_node, {0, 3, 4, 1, 5, 2});
return NodeVector{op::util::reshape(flat_node, Shape{n, c_flat, h * bs, w * bs})}; return NodeVector{builder::reshape(flat_node, Shape{n, c_flat, h * bs, w * bs})};
} }
shared_ptr<Node> op::DepthToSpace::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::DepthToSpace::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/fused/gemm.hpp" #include "ngraph/op/fused/gemm.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/reshape.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -49,15 +49,15 @@ NodeVector op::Gemm::decompose_op() const ...@@ -49,15 +49,15 @@ NodeVector op::Gemm::decompose_op() const
if (m_transA) if (m_transA)
{ {
A = ngraph::op::util::transpose(A); A = ngraph::builder::transpose(A);
} }
if (m_transB) if (m_transB)
{ {
B = ngraph::op::util::transpose(B); B = ngraph::builder::transpose(B);
} }
A = ngraph::op::util::flatten(A, 1); A = ngraph::builder::flatten(A, 1);
B = ngraph::op::util::flatten(B, 1); B = ngraph::builder::flatten(B, 1);
// A' * B' // A' * B'
std::shared_ptr<ngraph::Node> a_dot_b = std::make_shared<ngraph::op::Dot>(A, B); std::shared_ptr<ngraph::Node> a_dot_b = std::make_shared<ngraph::op::Dot>(A, B);
......
...@@ -19,9 +19,9 @@ ...@@ -19,9 +19,9 @@
#include "grn.hpp" #include "grn.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp" #include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
using namespace std; using namespace std;
...@@ -62,7 +62,7 @@ NodeVector op::GRN::decompose_op() const ...@@ -62,7 +62,7 @@ NodeVector op::GRN::decompose_op() const
{ {
Shape data_shape(4 - input_shape.size(), 1); Shape data_shape(4 - input_shape.size(), 1);
copy(begin(input_shape), end(input_shape), back_inserter(data_shape)); copy(begin(input_shape), end(input_shape), back_inserter(data_shape));
data = util::reshape(data, data_shape); data = builder::reshape(data, data_shape);
} }
// Calculate l2 norm across channels. // Calculate l2 norm across channels.
...@@ -74,7 +74,7 @@ NodeVector op::GRN::decompose_op() const ...@@ -74,7 +74,7 @@ NodeVector op::GRN::decompose_op() const
// get back original input tensor rank // get back original input tensor rank
if (input_shape.size() != 4) if (input_shape.size() != 4)
{ {
data = util::reshape(data, input_shape); data = builder::reshape(data, input_shape);
} }
return {data}; return {data};
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
#include <iterator> #include <iterator>
#include "ngraph/builder/norm.hpp" #include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "normalize.hpp" #include "normalize.hpp"
using namespace std; using namespace std;
...@@ -94,7 +94,7 @@ NodeVector op::Normalize::decompose_op() const ...@@ -94,7 +94,7 @@ NodeVector op::Normalize::decompose_op() const
{ {
Shape data_shape(4 - input_shape.size(), 1); Shape data_shape(4 - input_shape.size(), 1);
copy(begin(input_shape), end(input_shape), back_inserter(data_shape)); copy(begin(input_shape), end(input_shape), back_inserter(data_shape));
data = util::reshape(data, data_shape); data = builder::reshape(data, data_shape);
} }
// Calculate norm over CHW axes. // Calculate norm over CHW axes.
...@@ -128,7 +128,7 @@ NodeVector op::Normalize::decompose_op() const ...@@ -128,7 +128,7 @@ NodeVector op::Normalize::decompose_op() const
// get back original input tensor rank // get back original input tensor rank
if (input_shape.size() != 4) if (input_shape.size() != 4)
{ {
data = util::reshape(data, input_shape); data = builder::reshape(data, input_shape);
} }
return {data}; return {data};
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include "ngraph/node.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "space_to_depth.hpp" #include "space_to_depth.hpp"
...@@ -73,9 +72,9 @@ NodeVector op::SpaceToDepth::decompose_op() const ...@@ -73,9 +72,9 @@ NodeVector op::SpaceToDepth::decompose_op() const
// First we have to disperse the data from height and width channels, then // First we have to disperse the data from height and width channels, then
// rearrange them so as appropriate chunks of data where close to their // rearrange them so as appropriate chunks of data where close to their
// destination place. Finally squeeze data from respective dimensions. // destination place. Finally squeeze data from respective dimensions.
shared_ptr<Node> flat_node = op::util::reshape(data, Shape{n, c, h_flat, bs, w_flat, bs}); shared_ptr<Node> flat_node = builder::reshape(data, Shape{n, c, h_flat, bs, w_flat, bs});
flat_node = op::util::reorder_axes(flat_node, {0, 3, 5, 1, 2, 4}); flat_node = builder::reorder_axes(flat_node, {0, 3, 5, 1, 2, 4});
return NodeVector{op::util::reshape(flat_node, Shape{n, c_high, h_flat, w_flat})}; return NodeVector{builder::reshape(flat_node, Shape{n, c_high, h_flat, w_flat})};
} }
shared_ptr<Node> op::SpaceToDepth::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::SpaceToDepth::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#pragma once #pragma once
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp" #include "ngraph/op/util/fused_op.hpp"
namespace ngraph namespace ngraph
......
...@@ -16,9 +16,13 @@ ...@@ -16,9 +16,13 @@
#pragma once #pragma once
#include "ngraph/axis_vector.hpp" #include <cstddef>
#include <memory>
#include <vector>
#include "ngraph/builder/reshape.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/util.hpp" #include "ngraph/shape.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -34,7 +38,10 @@ namespace ngraph ...@@ -34,7 +38,10 @@ namespace ngraph
/// \return The node representing a Reshape operation. /// \return The node representing a Reshape operation.
/// ///
std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
const Shape& shape); const Shape& shape)
{
return builder::reshape(node, shape);
}
/// \brief Permute axes according to specified axes_order parameter. /// \brief Permute axes according to specified axes_order parameter.
/// ///
...@@ -43,14 +50,20 @@ namespace ngraph ...@@ -43,14 +50,20 @@ namespace ngraph
/// ///
/// \return: New node with permuted axes. /// \return: New node with permuted axes.
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes_order); std::vector<std::size_t> axes_order)
{
return builder::reorder_axes(node, axes_order);
}
/// \brief Return transposed tensor (with axes in reversed order). /// \brief Return transposed tensor (with axes in reversed order).
/// ///
/// \param node Input tensor we want to transpose /// \param node Input tensor we want to transpose
/// ///
/// \return: New node with reversed dimensions. /// \return: New node with reversed dimensions.
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node); std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node)
{
return builder::transpose(node);
}
/// \brief Flatten the input tensor into a 2D matrix. /// \brief Flatten the input tensor into a 2D matrix.
/// ///
...@@ -59,7 +72,10 @@ namespace ngraph ...@@ -59,7 +72,10 @@ namespace ngraph
/// ///
/// \return The new node being a 2D matrix representing flattened input node. /// \return The new node being a 2D matrix representing flattened input node.
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis); int axis)
{
return builder::flatten(node, axis);
}
} // namespace util } // namespace util
} // namespace op } // namespace op
} // namespace ngraph } // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment