Commit ccc97f3b authored by baojun's avatar baojun Committed by Sang Ik Lee

Add fluid layout converter (#4061)

* Add fluid layout converter

* remove const
parent 1edba878
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
# Add files here # Add files here
target_sources (ngraph PRIVATE target_sources (ngraph PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/operators/matmul.hpp ${CMAKE_CURRENT_SOURCE_DIR}/operators/layout_converter.cpp
${CMAKE_CURRENT_SOURCE_DIR}/operators/layout_converter.hpp
${CMAKE_CURRENT_SOURCE_DIR}/operators/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/operators/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/operators/matmul.hpp
${CMAKE_CURRENT_SOURCE_DIR}/operators/pool.cpp ${CMAKE_CURRENT_SOURCE_DIR}/operators/pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/operators/pool.hpp ${CMAKE_CURRENT_SOURCE_DIR}/operators/pool.hpp
${CMAKE_CURRENT_SOURCE_DIR}/operators/reduce_sum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/operators/reduce_sum.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/frontend/fluid/operators/layout_converter.hpp"
#include "ngraph/op/reshape.hpp"
using namespace std;
using namespace ngraph::fluid;
constexpr NodeTypeInfo LayoutConverter::type_info;
LayoutConverter::LayoutConverter(const Output<Node>& x, const int mode)
: FusedOp({x})
, m_mode(mode)
{
constructor_validate_and_infer_types();
}
NodeVector LayoutConverter::decompose_op() const
{
auto x = input_value(0);
auto x_shape = get_input_shape(0);
int mode = get_mode();
NODE_VALIDATION_CHECK(this, x_shape.size() == 4, "Input rank is not 4");
AxisVector axis_vec;
switch (mode)
{
case 1: axis_vec = {0, 3, 1, 2}; break;
case 2: axis_vec = {0, 2, 3, 1}; break;
default: throw ngraph_error("Unsupported layout convert mode");
}
Shape out_shape = x_shape;
for (size_t i = 0; i < 4; ++i)
{
out_shape[i] = x_shape[axis_vec[i]];
}
return {make_shared<op::Reshape>(x, axis_vec, out_shape)};
}
shared_ptr<Node> LayoutConverter::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<LayoutConverter>(new_args.at(0), get_mode());
}
void LayoutConverter::validate_and_infer_types()
{
auto shape = get_input_partial_shape(0);
if (shape.is_dynamic())
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
else
{
FusedOp::validate_and_infer_types();
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace fluid
{
/// \brief Fluid layout converter
class NGRAPH_API LayoutConverter : public ngraph::op::util::FusedOp
{
public:
static constexpr NodeTypeInfo type_info{"FluidLayoutConverter", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
LayoutConverter() = default;
/// \brief Constructs a LayoutConverter operation.
///
/// \param x Input x
/// \param mode : 1. nhwc->nchw, 2 hchw->nhwc
LayoutConverter(const Output<Node>& x, const int mode);
virtual NodeVector decompose_op() const override;
virtual void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int get_mode() const { return m_mode; }
protected:
int m_mode;
};
} // namespace fluid
} // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment