Commit 029db8df authored by Jayaram Bobba, committed by Robert Kimball

Add Tile op (#2865)

* Add Tile op

* Added Tile op header to GPU backend

* Add dummy GPU tile emitter
parent ea4a89ec
@@ -156,6 +156,8 @@ set (SRC
    op/experimental/quantized_max_pool.hpp
    op/experimental/shape_of.cpp
    op/experimental/shape_of.hpp
    op/experimental/tile.cpp
    op/experimental/tile.hpp
    op/experimental/quantized_dot.cpp
    op/experimental/quantized_dot.hpp
    op/experimental/quantized_dot_bias.cpp
......
@@ -92,6 +92,7 @@
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "tile.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
op::Tile::Tile(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& repeats)
: Op("Tile", check_single_output_args({arg, repeats}))
{
constructor_validate_and_infer_types();
}
void op::Tile::validate_and_infer_types()
{
auto arg_et = get_input_element_type(0);
// Repeats should have integer data type. For now we only allow i64
auto repeats_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
repeats_et.compatible(element::Type_t::i64),
"Tile repeats must have element type i64, but has ",
repeats_et);
auto arg_shape = get_input_partial_shape(0);
auto arg_rank = arg_shape.rank();
auto repeats_shape = get_input_partial_shape(1);
auto repeats_rank = repeats_shape.rank();
auto output_rank = Rank::dynamic();
NODE_VALIDATION_CHECK(this, repeats_rank.compatible(1), "Shape of repeats must be of rank 1");
if (arg_rank.is_static())
{
// Repeats shapes should be of form {arg_rank} or dynamic
NODE_VALIDATION_CHECK(this,
repeats_shape.compatible(PartialShape{arg_rank}),
"Arg and padding below ranks mismatch");
output_rank = arg_rank;
}
auto out_shape = PartialShape::dynamic(output_rank);
if (auto const_repeats = dynamic_pointer_cast<op::Constant>(get_argument(1)))
{
if (arg_shape.is_static())
{
auto shape = arg_shape.to_shape();
auto repeats_val = const_repeats->get_vector<int64_t>();
Shape output_shape(shape.size());
for (size_t i = 0; i < shape.size(); i++)
{
output_shape[i] = shape[i] * repeats_val[i];
}
set_output_type(0, arg_et, output_shape);
}
else
{
set_output_type(0, arg_et, out_shape);
}
}
else
{
set_output_type(0, arg_et, out_shape);
}
set_input_is_relevant_to_shape(1);
}
shared_ptr<Node> op::Tile::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Tile>(new_args.at(0), new_args.at(1));
}
// TODO: This function is not implemented!
void op::Tile::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
throw ngraph_error("generate_adjoints not implemented for Tile");
}
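
The backends below only stub Tile out, so as a reference for the op's semantics, here is a minimal sketch (not part of this commit; tile_reference is a hypothetical helper) of the computation Tile describes, assuming row-major layout:

#include <cstddef>
#include <vector>

// Each output coordinate wraps back onto the input coordinate taken modulo
// the corresponding input dimension; that wrap-around is the tiling.
template <typename T>
void tile_reference(const T* in,
                    T* out,
                    const std::vector<size_t>& in_shape,
                    const std::vector<size_t>& out_shape)
{
    size_t rank = in_shape.size();
    size_t out_count = 1;
    for (size_t d : out_shape)
    {
        out_count *= d;
    }
    for (size_t flat = 0; flat < out_count; ++flat)
    {
        // Decompose the flat output index into per-dimension coordinates,
        // wrap each into the input extent, and recompose an input index.
        size_t rem = flat;
        size_t in_index = 0;
        for (size_t d = 0; d < rank; ++d)
        {
            size_t stride = 1;
            for (size_t e = d + 1; e < rank; ++e)
            {
                stride *= out_shape[e];
            }
            size_t coord = rem / stride;
            rem %= stride;
            in_index = in_index * in_shape[d] + coord % in_shape[d];
        }
        out[flat] = in[in_index];
    }
}
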
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include "ngraph/op/op.hpp"

namespace ngraph
{
    namespace op
    {
        /// \brief Dynamic tiling operation which repeats a tensor multiple
        ///        times along each dimension.
        class Tile : public Op
        {
        public:
            /// \brief Perform dynamic tiling of a tensor
            ///
            /// \param arg The node producing the input tensor to be tiled.
            /// \param repeats The node producing the per-dimension replication factors.
            Tile(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& repeats);

            void validate_and_infer_types() override;

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;
        };
    }
}
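
Construction follows the usual nGraph pattern (the type_prop test below uses the same calls); a minimal usage sketch:

// Tile a 2x3 float tensor twice along each axis; the inferred output shape is {4, 6}.
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto repeats = op::Constant::create(element::i64, Shape{2}, {2, 2});
auto tile = make_shared<op::Tile>(arg, repeats);
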
@@ -151,5 +151,6 @@ NGRAPH_OP(Sum, ngraph::op)
NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(TopK, ngraph::op)
NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(EmbeddingLookup, ngraph::op)
@@ -73,6 +73,7 @@
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/gather.hpp"
@@ -1436,6 +1437,11 @@ std::string runtime::gpu::GPU_Emitter::emit_DynPad(EMIT_ARGS)
    throw unsupported_op("Unsupported op '" + node->description() + "'");
}

std::string runtime::gpu::GPU_Emitter::emit_Tile(EMIT_ARGS)
{
    throw unsupported_op("Unsupported op '" + node->description() + "'");
}

std::string runtime::gpu::GPU_Emitter::emit_Transpose(EMIT_ARGS)
{
    throw unsupported_op("Unsupported op '" + node->description() + "'");
......
@@ -2004,6 +2004,7 @@ shared_ptr<runtime::Executable>
    case OP_TYPEID::ShapeOf:
    case OP_TYPEID::SpaceToDepth:
    case OP_TYPEID::StopGradient:
    case OP_TYPEID::Tile:
    case OP_TYPEID::Transpose:
    default:
    {
......
@@ -1361,6 +1361,7 @@ private:
    case OP_TYPEID::DynBroadcast:
    case OP_TYPEID::Transpose:
    case OP_TYPEID::DynPad:
    case OP_TYPEID::Tile:
    default: throw unsupported_op("Unsupported op '" + node.description() + "'");
#pragma GCC diagnostic pop
    }
......
@@ -63,6 +63,7 @@
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
@@ -1418,6 +1419,11 @@ static shared_ptr<ngraph::Function>
        node = make_shared<op::Tanh>(args[0]);
        break;
    }
    case OP_TYPEID::Tile:
    {
        node = make_shared<op::Tile>(args[0], args[1]);
        break;
    }
    case OP_TYPEID::TopK:
    {
        auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
@@ -2100,6 +2106,8 @@ static json write(const Node& n, bool binary_constant_data)
    }
    case OP_TYPEID::Tanh: { break;
    }
    case OP_TYPEID::Tile: { break;
    }
    case OP_TYPEID::TopK:
    {
        auto tmp = dynamic_cast<const op::TopK*>(&n);
......
@@ -4295,6 +4295,15 @@ TEST(
}
}

TEST(type_prop, tile)
{
    auto param0 = make_shared<op::Parameter>(element::f32, Shape{6, 8, 10});
    auto param1 = op::Constant::create(element::i64, Shape{3}, {3, 4, 1});
    auto top = make_shared<op::Tile>(param0, param1);
    ASSERT_EQ(top->get_element_type(), element::f32);
    ASSERT_EQ(top->get_shape(), (Shape{18, 32, 10}));
}
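
A companion case this commit does not add, sketched here to show the non-constant repeats path, where validate_and_infer_types can only fix the output rank:

// Hypothetical test, not part of this commit.
TEST(type_prop, tile_dynamic_repeats)
{
    auto param0 = make_shared<op::Parameter>(element::f32, Shape{6, 8, 10});
    auto param1 = make_shared<op::Parameter>(element::i64, Shape{3});
    auto top = make_shared<op::Tile>(param0, param1);
    ASSERT_EQ(top->get_element_type(), element::f32);
    // Repeats is not a Constant, so only the rank of the output is known.
    ASSERT_TRUE(top->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(3)));
}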

TEST(type_prop, one_hot_deduce_scalar)
{
    auto param = make_shared<op::Parameter>(element::i32, Shape{});
......