Commit f66fd5a9 authored by baojun, committed by Sang Ik Lee

Add fluid lookup_table_v2 op (#4185)

* add lookup table v2 fused op placeholder

* implement fprop

* fix fprop and add ut

* turn off fluid test build

* add missed ut

* remove get_output_shared_ptr

* implement bprop
parent 90c2f5bd
...
@@ -18,6 +18,8 @@
 target_sources (ngraph PRIVATE
     ${CMAKE_CURRENT_SOURCE_DIR}/operators/layout_converter.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/operators/layout_converter.hpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/operators/lookup_table.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/operators/lookup_table.hpp
     ${CMAKE_CURRENT_SOURCE_DIR}/operators/matmul.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/operators/matmul.hpp
     ${CMAKE_CURRENT_SOURCE_DIR}/operators/pool.cpp
...
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/frontend/fluid/operators/lookup_table.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph::fluid;
constexpr NodeTypeInfo LookupTable2::type_info;
LookupTable2::LookupTable2(const Output<Node>& w,
const Output<Node>& ids,
const int64_t padding_idx)
: FusedOp({w, ids})
, m_padding_idx(padding_idx)
{
constructor_validate_and_infer_types();
}
NodeVector LookupTable2::decompose_op() const
{
auto w = input_value(0);
auto ids = input_value(1);
auto padding_idx = get_padding_idx();
auto table_shape = get_input_shape(0);
NODE_VALIDATION_CHECK(
this, table_shape.size() == 2, "The dimension of the lookup table must be 2");
auto row_number = table_shape[0];
auto masked_w = w;
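// If a padding index is given, zero out that row of the table before gathering.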
if (padding_idx != -1)
{
vector<size_t> mask(row_number, 1);
mask[padding_idx] = 0;
auto mask_node = make_shared<op::Constant>(w.get_element_type(), Shape{row_number}, mask);
auto mask_bcast = make_shared<op::Broadcast>(mask_node, table_shape, AxisSet{1});
masked_w = w * mask_bcast;
}
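// Gather the (possibly masked) table rows selected by ids.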
auto out = make_shared<op::Gather>(masked_w, ids);
return {out};
}
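// --- Editorial reference sketch (not part of this commit) ---------------------
// The decomposition above is equivalent to zeroing the padding row of the table
// and then indexing it with ids. The helper below (reference_lookup is an
// illustrative name, not an nGraph API) spells that out for a row-major
// [rows, width] table; it needs <vector> and <cstdint>.
std::vector<float> reference_lookup(const std::vector<float>& table,
                                    size_t width,
                                    const std::vector<int64_t>& ids,
                                    int64_t padding_idx)
{
    std::vector<float> out(ids.size() * width);
    for (size_t i = 0; i < ids.size(); ++i)
    {
        for (size_t j = 0; j < width; ++j)
        {
            // Rows equal to padding_idx read back as zeros, mirroring the mask broadcast.
            out[i * width + j] = (ids[i] == padding_idx)
                                     ? 0.0f
                                     : table[static_cast<size_t>(ids[i]) * width + j];
        }
    }
    return out;
}
// -------------------------------------------------------------------------------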
shared_ptr<Node> LookupTable2::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<LookupTable2>(new_args.at(0), new_args.at(1), get_padding_idx());
}
void LookupTable2::pre_validate_and_infer_types()
{
auto pshape_w = get_input_partial_shape(0);
auto pshape_ids = get_input_partial_shape(1);
if (pshape_w.is_dynamic() || pshape_ids.is_dynamic())
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
}
constexpr NodeTypeInfo LookupTable2Grad::type_info;
LookupTable2Grad::LookupTable2Grad(const Output<Node>& w,
const Output<Node>& ids,
const Output<Node>& dout)
: FusedOp({w, ids, dout})
{
constructor_validate_and_infer_types();
}
void LookupTable2Grad::pre_validate_and_infer_types()
{
if (get_input_partial_shape(0).is_dynamic() || get_input_partial_shape(1).is_dynamic() ||
get_input_partial_shape(2).is_dynamic())
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
}
shared_ptr<Node> LookupTable2Grad::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<LookupTable2Grad>(new_args.at(0), new_args.at(1), new_args.at(2));
}
NodeVector LookupTable2Grad::decompose_op() const
{
auto w = input_value(0);
auto ids = input_value(1);
auto dout = input_value(2);
auto shape_w = get_input_shape(0);
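// Start from a zero-filled table and scatter-add each delta row into the row selected by its id.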
auto w0 = op::Constant::create(dout.get_element_type(), shape_w, {0});
auto dw = make_shared<op::ScatterAdd>(w0, ids, dout);
return {dw};
}
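// --- Editorial reference sketch (not part of this commit) ---------------------
// The backward decomposition starts from a zero table and scatter-adds each row
// of dout into the table row selected by its id, accumulating when an id repeats.
// reference_lookup_grad is an illustrative name; it needs <vector> and <cstdint>.
std::vector<float> reference_lookup_grad(size_t rows,
                                         size_t width,
                                         const std::vector<int64_t>& ids,
                                         const std::vector<float>& dout)
{
    std::vector<float> dw(rows * width, 0.0f); // the zero constant w0 above
    for (size_t i = 0; i < ids.size(); ++i)
    {
        for (size_t j = 0; j < width; ++j)
        {
            // The same accumulation op::ScatterAdd performs.
            dw[static_cast<size_t>(ids[i]) * width + j] += dout[i * width + j];
        }
    }
    return dw;
}
// -------------------------------------------------------------------------------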
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace fluid
{
/// \brief Fluid lookup_table2
class NGRAPH_API LookupTable2 : public ngraph::op::util::FusedOp
{
public:
static constexpr NodeTypeInfo type_info{"FluidLookupTable2", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
LookupTable2() = default;
/// \brief Constructs a LookupTable2 operation.
///
/// \param w Input weight table
/// \param ids Input lookup ids
/// \param padding_idx Index of the padding row to zero out, or -1 to disable masking
LookupTable2(const Output<Node>& w, const Output<Node>& ids, const int64_t padding_idx);
virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int64_t get_padding_idx() const { return m_padding_idx; }
protected:
int64_t m_padding_idx{-1};
};
/// \brief Fluid lookup_table2_grad
class NGRAPH_API LookupTable2Grad : public ngraph::op::util::FusedOp
{
public:
static constexpr NodeTypeInfo type_info{"FluidLookupTable2Grad", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
LookupTable2Grad() = default;
/// \brief Constructs a LookupTable2Grad operation.
///
/// \param w Input weight table
/// \param ids Input lookup ids
/// \param dout Input delta
LookupTable2Grad(const Output<Node>& w,
const Output<Node>& ids,
const Output<Node>& dout);
virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
...
@@ -41,6 +41,7 @@ endif()
 set(SRC
     main.cpp
     reduce_sum.cpp
+    lookup_table.cpp
 )
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/../../../../../../test)
...
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
#undef IN_NGRAPH_LIBRARY
#include "gtest/gtest.h"
#include "ngraph/frontend/fluid/operators/lookup_table.hpp"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
static std::mt19937_64 random_generator;
using namespace std;
using namespace ngraph;
static string s_manifest = "test.manifest";
NGRAPH_TEST(CPU, fluid_lookup_table_v2)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2, 3, 4};
Shape out_shape{2, 2, 3, 4, 2};
auto P = make_shared<op::Parameter>(element::u8, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<fluid::LookupTable2>(P, I, -1);
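// padding_idx of -1 disables padding-row masking.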
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto p = backend->create_tensor(element::u8, params_shape);
copy_data(p, vector<uint8_t>{10, 11, 20, 21, 30, 31});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2,
0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2,
0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 2});
auto result = backend->create_tensor(element::u8, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close(
(vector<uint8_t>{10, 11, 20, 21, 20, 21, 30, 31, 10, 11, 20, 21, 20, 21, 30, 31,
10, 11, 20, 21, 20, 21, 30, 31, 10, 11, 20, 21, 20, 21, 30, 31,
10, 11, 20, 21, 20, 21, 30, 31, 10, 11, 20, 21, 20, 21, 30, 31,
10, 11, 20, 21, 20, 21, 30, 31, 10, 11, 20, 21, 20, 21, 30, 31,
10, 11, 20, 21, 20, 21, 30, 31, 10, 11, 20, 21, 20, 21, 30, 31,
10, 11, 20, 21, 20, 21, 30, 31, 10, 11, 20, 21, 20, 21, 30, 31}),
read_vector<uint8_t>(result)));
}
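// --- Editorial sketch: a hypothetical companion test (not in this commit) -----
// It exercises LookupTable2Grad the same way the forward test above exercises
// LookupTable2: both deltas carry id 1, so they accumulate into row 1 of dw.
NGRAPH_TEST(CPU, fluid_lookup_table_v2_grad)
{
    Shape w_shape{3, 2};
    Shape ids_shape{2};
    Shape dout_shape{2, 2};
    auto W = make_shared<op::Parameter>(element::f32, w_shape);
    auto I = make_shared<op::Parameter>(element::i32, ids_shape);
    auto D = make_shared<op::Parameter>(element::f32, dout_shape);
    auto G = make_shared<fluid::LookupTable2Grad>(W, I, D);
    auto f = make_shared<Function>(G, ParameterVector{W, I, D});
    auto backend = runtime::Backend::create("CPU");
    auto w = backend->create_tensor(element::f32, w_shape);
    copy_data(w, vector<float>{0, 0, 0, 0, 0, 0});
    auto i = backend->create_tensor(element::i32, ids_shape);
    copy_data(i, vector<int32_t>{1, 1});
    auto d = backend->create_tensor(element::f32, dout_shape);
    copy_data(d, vector<float>{1, 2, 3, 4});
    auto result = backend->create_tensor(element::f32, w_shape);
    auto c = backend->compile(f);
    c->call_with_validate({result}, {w, i, d});
    // Expected gradient: row 1 accumulates (1,2)+(3,4) = (4,6); other rows stay zero.
    EXPECT_TRUE(test::all_close_f(vector<float>{0, 0, 4, 6, 0, 0}, read_vector<float>(result)));
}
// -------------------------------------------------------------------------------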