Commit 59632bac authored by Sang Ik Lee's avatar Sang Ik Lee Committed by Scott Cyphers

gather, gather_nd (#2742)

* Temp.

* Put all the dummy files.

* Remove some compile errors.

* WIP: Add gather and gather_nd kernels.

* Temp save.

* Update comments for gather.

* Implement reference gather.

* Validate and infer shape.

* Style.

* Fix compile issues.

* Add serializer support.

* Fix interpreter compilation issues.

* WIP: Add UT

* WIP: Add UT

* gather_nd UT passing.

* Fix gather with no axis.

* Fix gather issue.

* Update unit_test.manifest for backends and add gather, gather_nd support for generic CPU.

* Add type_prop tests.

* Add CPU builders.

* Fix codegen.

* Make some UT numbers more readable.

* Style.

* Update Copyright Year

* Update Copyright Year

* Fix Typo.

* Remove unused variable.

* fix nv gpu build error

* Fix intel gpu compilation.

* Add basic docstring.

* Allow 1D indices for gather_nd.

* Allow scalar indices for gather.

* Update unit_test manifest files.

* Style.

* Add indices element type check and add failing type_prop checks.

* Update docstring.

* Fix incorrect test names in unit_test.manifest

* Missing header
parent a68cddb7
......@@ -165,6 +165,10 @@ set (SRC
op/experimental/transpose.hpp
op/floor.cpp
op/floor.hpp
op/gather.cpp
op/gather.hpp
op/gather_nd.cpp
op/gather_nd.hpp
op/get_output_element.cpp
op/get_output_element.hpp
op/greater.cpp
......
......@@ -96,6 +96,8 @@
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
......
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/gather.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
static int PARAMS = 0;
static int INDICES = 1;
shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const
{
    // Reject argument lists whose size differs from this node's input count.
    check_new_args_count(this, new_args);
    // Rebuild an equivalent Gather over the replacement inputs, keeping the axis.
    auto replacement = make_shared<Gather>(new_args.at(PARAMS), new_args.at(INDICES), m_axis);
    return replacement;
}
// Shape/type inference for Gather.
//
// Checks:
//   - indices element type must be i32 or i64
//   - params rank (when static) must be greater than axis, i.e. at least axis + 1
//
// Result shape (when both input ranks are static):
//   params.shape[0 : axis] ++ indices.shape ++ params.shape[axis + 1 :]
// i.e. output rank is rank(params) + rank(indices) - 1.
void op::Gather::validate_and_infer_types()
{
    element::Type result_et = get_input_element_type(PARAMS); // output keeps params' type
    element::Type indices_et = get_input_element_type(INDICES);

    const PartialShape& params_shape = get_input_partial_shape(PARAMS);
    const PartialShape& indices_shape = get_input_partial_shape(INDICES);

    NODE_VALIDATION_CHECK(this,
                          indices_et == element::i32 || indices_et == element::i64,
                          "Indices element type must be i64 or i32");

    // params rank must be at least (axis + 1)
    // indices value must be in range [0, params.shape[axis]).
    // output rank is rank(params) + rank(indices) - 1
    NODE_VALIDATION_CHECK(this,
                          params_shape.rank().is_dynamic() ||
                              static_cast<size_t>(params_shape.rank()) >
                                  static_cast<size_t>(m_axis),
                          "params rank is expected to be at least axis + 1");

    PartialShape result_shape;
    if (params_shape.rank().is_static() && indices_shape.rank().is_static())
    {
        // The axis dimension of params is replaced by the entire indices shape.
        std::vector<Dimension> result_dims(static_cast<size_t>(params_shape.rank()) +
                                           static_cast<size_t>(indices_shape.rank()) - 1);
        size_t i = 0;
        // 1) params dimensions before the gather axis
        for (; i < static_cast<size_t>(m_axis); i++)
        {
            result_dims[i] = params_shape[i];
        }
        // 2) all indices dimensions, standing in for params.shape[axis]
        for (size_t j = 0; j < static_cast<size_t>(indices_shape.rank()); i++, j++)
        {
            result_dims[i] = indices_shape[j];
        }
        // 3) params dimensions after the gather axis
        for (size_t j = static_cast<size_t>(m_axis) + 1;
             j < static_cast<size_t>(params_shape.rank());
             i++, j++)
        {
            result_dims[i] = params_shape[j];
        }
        result_shape = PartialShape(result_dims);
    }
    else
    {
        // Either input has dynamic rank: nothing further can be inferred.
        result_shape = PartialShape::dynamic();
    }

    set_output_type(0, result_et, result_shape);
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Gather slices from axis of params according to indices
class Gather : public Op
{
public:
    /// \param params The tensor from which slices are gathered
    /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
    /// \param axis Axis in params to gather
    Gather(const std::shared_ptr<Node>& params,
           const std::shared_ptr<Node>& indices,
           size_t axis = 0)
        : Op("Gather", check_single_output_args({params, indices}))
        , m_axis(axis)
    {
        constructor_validate_and_infer_types();
    }

    void validate_and_infer_types() override;

    /// Autodiff is not implemented for Gather; always throws.
    void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override
    {
        throw ngraph_error("Not yet implemented");
    }

    /// \return The axis of params along which slices are gathered.
    size_t get_axis() const { return m_axis; }
    virtual std::shared_ptr<Node>
        copy_with_new_args(const NodeVector& new_args) const override;

protected:
    size_t m_axis; // gather axis; defaults to 0
};
}
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
static int PARAMS = 0;
static int INDICES = 1;
shared_ptr<Node> op::GatherND::copy_with_new_args(const NodeVector& new_args) const
{
    // Reject argument lists whose size differs from this node's input count.
    check_new_args_count(this, new_args);
    // Rebuild an equivalent GatherND over the replacement inputs.
    auto replacement = make_shared<GatherND>(new_args.at(PARAMS), new_args.at(INDICES));
    return replacement;
}
// Shape/type inference for GatherND.
//
// Checks:
//   - indices element type must be i32 or i64
//   - both params and indices must have rank >= 1 (when static)
//   - indices.shape[-1] (the coordinate-tuple length) must not exceed rank(params)
//
// Result shape (when both input ranks are static):
//   indices.shape[0 : -1] ++ params.shape[indices.shape[-1] :]
void op::GatherND::validate_and_infer_types()
{
    element::Type result_et = get_input_element_type(PARAMS); // output keeps params' type
    element::Type indices_et = get_input_element_type(INDICES);

    const PartialShape& params_shape = get_input_partial_shape(PARAMS);
    const PartialShape& indices_shape = get_input_partial_shape(INDICES);

    NODE_VALIDATION_CHECK(this,
                          indices_et == element::i32 || indices_et == element::i64,
                          "Indices element type must be i64 or i32");

    NODE_VALIDATION_CHECK(this,
                          indices_shape.rank().is_dynamic() ||
                              static_cast<size_t>(indices_shape.rank()) >= 1,
                          "indices rank is expected to be at least 1");

    NODE_VALIDATION_CHECK(this,
                          params_shape.rank().is_dynamic() ||
                              static_cast<size_t>(params_shape.rank()) >= 1,
                          "params rank is expected to be at least 1");

    // The innermost dimension of indices selects coordinates into params, so it
    // can address at most rank(params) dimensions.
    // NOTE(review): the static_cast assumes indices.shape[-1] is a static
    // dimension whenever both ranks are static -- confirm behavior when that
    // dimension is dynamic.
    NODE_VALIDATION_CHECK(
        this,
        params_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() ||
            static_cast<size_t>(indices_shape[static_cast<size_t>(indices_shape.rank()) - 1]) <=
                static_cast<size_t>(params_shape.rank()),
        "last dimension of indices can be at most the rank of params");

    PartialShape result_shape;
    if (params_shape.rank().is_static() && indices_shape.rank().is_static())
    {
        // rank(out) = rank(indices) - 1 + rank(params) - indices.shape[-1]
        std::vector<Dimension> result_dims(
            static_cast<size_t>(indices_shape.rank()) - 1 +
            static_cast<size_t>(params_shape.rank()) -
            static_cast<size_t>(indices_shape[static_cast<size_t>(indices_shape.rank()) - 1]));
        size_t i = 0;
        // 1) all but the innermost indices dimensions
        for (; i < static_cast<size_t>(indices_shape.rank()) - 1; i++)
        {
            result_dims[i] = indices_shape[i];
        }
        // 2) params dimensions not consumed by the coordinate tuples
        for (size_t j = static_cast<size_t>(
                 indices_shape[static_cast<size_t>(indices_shape.rank()) - 1]);
             j < static_cast<size_t>(params_shape.rank());
             i++, j++)
        {
            result_dims[i] = params_shape[j];
        }
        result_shape = PartialShape(result_dims);
    }
    else
    {
        // Either input has dynamic rank: result rank is unknown.
        result_shape = PartialShape::dynamic();
    }

    set_output_type(0, result_et, result_shape);
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Gather slices from params with shapes given by indices
class GatherND : public Op
{
public:
    /// \param params The tensor from which slices are gathered
    /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
    GatherND(const std::shared_ptr<Node>& params, const std::shared_ptr<Node>& indices)
        : Op("GatherND", check_single_output_args({params, indices}))
    {
        constructor_validate_and_infer_types();
    }

    void validate_and_infer_types() override;

    /// Autodiff is not implemented for GatherND; always throws.
    void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override
    {
        throw ngraph_error("Not yet implemented");
    }

    virtual std::shared_ptr<Node>
        copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
......@@ -90,6 +90,8 @@ NGRAPH_OP(Equal, ngraph::op)
NGRAPH_OP(Erf, ngraph::op)
NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(Floor, ngraph::op)
NGRAPH_OP(Gather, ngraph::op)
NGRAPH_OP(GatherND, ngraph::op)
NGRAPH_OP(GenerateMask, ngraph::op)
NGRAPH_OP(GetOutputElement, ngraph::op)
NGRAPH_OP(Greater, ngraph::op)
......
......@@ -47,6 +47,8 @@ set(SRC
builder/dot.cpp
builder/embedding_lookup.cpp
builder/erf.cpp
builder/gather.cpp
builder/gather_nd.cpp
builder/leaky_relu.cpp
builder/lstm.cpp
builder/lrn.cpp
......
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstring>
#include "ngraph/op/gather.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/gather.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
// CPU builder for Gather: selects a reference::gather kernel specialized on
// (params element type) x (indices element type) and appends it to the
// external function's functor list.
//
// Supported combinations here: params f32/f64, indices i32/i64; any other
// combination throws ngraph_error.
template <>
void Builder::BUILDER_DECL(ngraph::op::Gather)
{
    auto& functors = external_function->get_functors();
    const ngraph::op::Gather* gather = static_cast<const ngraph::op::Gather*>(node);
    CPUKernelFunctor functor;

    // Tensor-data handles are captured by reference ([&]) so the functor sees
    // the buffers bound at execution time; shapes and axis are captured by value.
    auto& params_tensor = external_function->get_tensor_data(args[0].get_name());
    auto& indices_tensor = external_function->get_tensor_data(args[1].get_name());
    auto& out_tensor = external_function->get_tensor_data(out[0].get_name());

    if (args[1].get_element_type() != element::i64 &&
        args[1].get_element_type() != element::i32)
    {
        throw ngraph_error("Unsupported index element type");
    }
    bool is_int64 = args[1].get_element_type() == element::i64;

    auto axis = gather->get_axis();
    auto params_shape = args[0].get_shape();
    auto indices_shape = args[1].get_shape();
    auto out_shape = out[0].get_shape();
    auto element_type = args[0].get_element_type();

    if (element_type == element::f32)
    {
        if (is_int64)
        {
            functor = [&, params_shape, indices_shape, out_shape, axis](
                CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
                ngraph::runtime::reference::gather<float, int64_t>(
                    static_cast<float*>(params_tensor),
                    static_cast<int64_t*>(indices_tensor),
                    static_cast<float*>(out_tensor),
                    params_shape,
                    indices_shape,
                    out_shape,
                    axis);
            };
        }
        else
        {
            functor = [&, params_shape, indices_shape, out_shape, axis](
                CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
                ngraph::runtime::reference::gather<float, int32_t>(
                    static_cast<float*>(params_tensor),
                    static_cast<int32_t*>(indices_tensor),
                    static_cast<float*>(out_tensor),
                    params_shape,
                    indices_shape,
                    out_shape,
                    axis);
            };
        }
    }
    else if (element_type == element::f64)
    {
        if (is_int64)
        {
            functor = [&, params_shape, indices_shape, out_shape, axis](
                CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
                ngraph::runtime::reference::gather<double, int64_t>(
                    static_cast<double*>(params_tensor),
                    static_cast<int64_t*>(indices_tensor),
                    static_cast<double*>(out_tensor),
                    params_shape,
                    indices_shape,
                    out_shape,
                    axis);
            };
        }
        else
        {
            functor = [&, params_shape, indices_shape, out_shape, axis](
                CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
                ngraph::runtime::reference::gather<double, int32_t>(
                    static_cast<double*>(params_tensor),
                    static_cast<int32_t*>(indices_tensor),
                    static_cast<double*>(out_tensor),
                    params_shape,
                    indices_shape,
                    out_shape,
                    axis);
            };
        }
    }
    else
    {
        // Only f32/f64 params are wired up in this builder.
        throw ngraph_error("Unsupported type in CPU Builder for Gather");
    }

    functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(Gather);
}
}
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstring>
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
// CPU builder for GatherND: selects a reference::gather_nd kernel specialized
// on (params element type) x (indices element type) and appends it to the
// external function's functor list.
//
// Supported combinations here: params f32/f64, indices i32/i64; any other
// combination throws ngraph_error. Unlike Gather, there is no axis parameter.
template <>
void Builder::BUILDER_DECL(ngraph::op::GatherND)
{
    auto& functors = external_function->get_functors();
    CPUKernelFunctor functor;

    // Tensor-data handles are captured by reference ([&]) so the functor sees
    // the buffers bound at execution time; shapes are captured by value.
    auto& params_tensor = external_function->get_tensor_data(args[0].get_name());
    auto& indices_tensor = external_function->get_tensor_data(args[1].get_name());
    auto& out_tensor = external_function->get_tensor_data(out[0].get_name());

    if (args[1].get_element_type() != element::i64 &&
        args[1].get_element_type() != element::i32)
    {
        throw ngraph_error("Unsupported index element type");
    }
    bool is_int64 = args[1].get_element_type() == element::i64;

    auto params_shape = args[0].get_shape();
    auto indices_shape = args[1].get_shape();
    auto out_shape = out[0].get_shape();
    auto element_type = args[0].get_element_type();

    if (element_type == element::f32)
    {
        if (is_int64)
        {
            functor = [&, params_shape, indices_shape, out_shape](
                CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
                ngraph::runtime::reference::gather_nd<float, int64_t>(
                    static_cast<float*>(params_tensor),
                    static_cast<int64_t*>(indices_tensor),
                    static_cast<float*>(out_tensor),
                    params_shape,
                    indices_shape,
                    out_shape);
            };
        }
        else
        {
            functor = [&, params_shape, indices_shape, out_shape](
                CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
                ngraph::runtime::reference::gather_nd<float, int32_t>(
                    static_cast<float*>(params_tensor),
                    static_cast<int32_t*>(indices_tensor),
                    static_cast<float*>(out_tensor),
                    params_shape,
                    indices_shape,
                    out_shape);
            };
        }
    }
    else if (element_type == element::f64)
    {
        if (is_int64)
        {
            functor = [&, params_shape, indices_shape, out_shape](
                CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
                ngraph::runtime::reference::gather_nd<double, int64_t>(
                    static_cast<double*>(params_tensor),
                    static_cast<int64_t*>(indices_tensor),
                    static_cast<double*>(out_tensor),
                    params_shape,
                    indices_shape,
                    out_shape);
            };
        }
        else
        {
            functor = [&, params_shape, indices_shape, out_shape](
                CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
                ngraph::runtime::reference::gather_nd<double, int32_t>(
                    static_cast<double*>(params_tensor),
                    static_cast<int32_t*>(indices_tensor),
                    static_cast<double*>(out_tensor),
                    params_shape,
                    indices_shape,
                    out_shape);
            };
        }
    }
    else
    {
        // Only f32/f64 params are wired up in this builder.
        throw ngraph_error("Unsupported type in CPU Builder for GatherND");
    }

    functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(GatherND);
}
}
}
......@@ -63,6 +63,8 @@
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
......@@ -1807,6 +1809,50 @@ namespace ngraph
writer.block_end();
}
// Codegen emitter for Gather: writes a call to reference::gather<...> into the
// generated source, parameterized on the params C type and the indices element
// type. Shapes and the axis are baked in as literals.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Gather)
{
    auto gather = static_cast<const ngraph::op::Gather*>(node);
    // Same restriction as the builder path: only i32/i64 indices are supported.
    if (args[1].get_element_type() != element::i64 &&
        args[1].get_element_type() != element::i32)
    {
        throw ngraph_error("Unsupported index element type");
    }
    writer.block_begin();
    writer << "reference::gather<" << args[0].get_type() << ", "
           << args[1].get_element_type().c_type_string() << ">(" << args[0].get_name()
           << ",\n";
    writer << " " << args[1].get_name() << ",\n";
    writer << " " << out[0].get_name() << ",\n";
    // Shapes are emitted as brace-initializer lists.
    writer << " {" << join(args[0].get_shape()) << "},\n";
    writer << " {" << join(args[1].get_shape()) << "},\n";
    writer << " {" << join(out[0].get_shape()) << "},\n";
    writer << " " << gather->get_axis() << ");\n";
    writer.block_end();
}
// Codegen emitter for GatherND: writes a call to reference::gather_nd<...>
// into the generated source, parameterized on the params C type and the
// indices element type. Unlike Gather, there is no axis argument.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::GatherND)
{
    // Same restriction as the builder path: only i32/i64 indices are supported.
    if (args[1].get_element_type() != element::i64 &&
        args[1].get_element_type() != element::i32)
    {
        throw ngraph_error("Unsupported index element type");
    }
    writer.block_begin();
    writer << "reference::gather_nd<" << args[0].get_type() << ", "
           << args[1].get_element_type().c_type_string() << ">(" << args[0].get_name()
           << ",\n";
    writer << " " << args[1].get_name() << ",\n";
    writer << " " << out[0].get_name() << ",\n";
    // Shapes are emitted as brace-initializer lists.
    writer << " {" << join(args[0].get_shape()) << "},\n";
    writer << " {" << join(args[1].get_shape()) << "},\n";
    writer << " {" << join(out[0].get_shape()) << "});\n";
    writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Power)
{
......
......@@ -79,6 +79,8 @@
#include "ngraph/op/experimental/quantized_max_pool.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
......@@ -312,6 +314,8 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Divide), &runtime::cpu::CPU_Emitter::emit<op::Divide>},
{TI(ngraph::op::Equal), &runtime::cpu::CPU_Emitter::emit<op::Equal>},
{TI(ngraph::op::Erf), &runtime::cpu::CPU_Emitter::emit<op::Erf>},
{TI(ngraph::op::Gather), &runtime::cpu::CPU_Emitter::emit<op::Gather>},
{TI(ngraph::op::GatherND), &runtime::cpu::CPU_Emitter::emit<op::GatherND>},
{TI(ngraph::op::GetOutputElement), &runtime::cpu::CPU_Emitter::emit<op::GetOutputElement>},
{TI(ngraph::op::Greater), &runtime::cpu::CPU_Emitter::emit<op::Greater>},
{TI(ngraph::op::GreaterEq), &runtime::cpu::CPU_Emitter::emit<op::GreaterEq>},
......@@ -539,6 +543,8 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
#include "ngraph/runtime/reference/generate_mask.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/max.hpp"
......
......@@ -37,6 +37,7 @@
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
......@@ -90,6 +91,8 @@
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
#include "ngraph/runtime/reference/generate_mask.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
......@@ -767,6 +770,61 @@ private:
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
break;
}
case OP_TYPEID::Gather:
{
const op::Gather* gather = static_cast<const op::Gather*>(&node);
if (node.get_input_element_type(1) == element::i64)
{
reference::gather<T, int64_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
gather->get_axis());
}
else if (node.get_input_element_type(1) == element::i32)
{
reference::gather<T, int32_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int32_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
gather->get_axis());
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::GatherND:
{
if (node.get_input_element_type(1) == element::i64)
{
reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0));
}
else if (node.get_input_element_type(1) == element::i32)
{
reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int32_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0));
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::Greater:
{
size_t element_count = shape_size(node.get_output_shape(0));
......
......@@ -75,6 +75,8 @@
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
......@@ -639,6 +641,16 @@ std::string runtime::gpu::GPU_Emitter::emit_Floor(EMIT_ARGS)
return emit_elementwise<ngraph::op::Floor>(compiled_function, function_name, node, args, out);
}
std::string runtime::gpu::GPU_Emitter::emit_Gather(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_GatherND(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_GenerateMask(EMIT_ARGS)
{
throw ngraph_error("GenerateMask is not supported yet on NVIDIA GPU");
......
......@@ -137,3 +137,19 @@ create_tensor_2_output
erf
zero_sized_erf
model_erf
gather_no_axis
gather
gather_nd_scalar_from_2d
gather_nd_1d_from_2d
gather_nd_scalar_from_3d
gather_nd_1d_from_3d
gather_nd_2d_from_3d
gather_nd_batch_scalar_from_2d
gather_nd_batch_1d_from_2d
gather_nd_batch_scalar_from_3d
gather_nd_batch_1d_from_3d
gather_nd_batch_2d_from_3d
gather_scalar_indices_no_axis
gather_scalar_indices
gather_nd_single_indices
......@@ -2008,6 +2008,8 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::DynReshape:
case OP_TYPEID::DynSlice:
case OP_TYPEID::Erf:
case OP_TYPEID::Gather:
case OP_TYPEID::GatherND:
case OP_TYPEID::QuantizedAvgPool:
case OP_TYPEID::QuantizedConvolutionBias:
case OP_TYPEID::QuantizedConvolutionBiasAdd:
......
......@@ -44,3 +44,18 @@ pad_reflect_2d_with_neg
# Not implemented
erf
zero_sized_erf
gather_no_axis
gather
gather_nd_scalar_from_2d
gather_nd_1d_from_2d
gather_nd_scalar_from_3d
gather_nd_1d_from_3d
gather_nd_2d_from_3d
gather_nd_batch_scalar_from_2d
gather_nd_batch_1d_from_2d
gather_nd_batch_scalar_from_3d
gather_nd_batch_1d_from_3d
gather_nd_batch_2d_from_3d
gather_scalar_indices_no_axis
gather_scalar_indices
gather_nd_single_indices
......@@ -40,6 +40,7 @@
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
......@@ -94,6 +95,8 @@
#include "ngraph/runtime/reference/erf.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
#include "ngraph/runtime/reference/generate_mask.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
......@@ -812,6 +815,61 @@ private:
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Gather:
{
const op::Gather* gather = static_cast<const op::Gather*>(&node);
if (node.get_input_element_type(1) == element::i64)
{
reference::gather<T, int64_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
gather->get_axis());
}
else if (node.get_input_element_type(1) == element::i32)
{
reference::gather<T, int32_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int32_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
gather->get_axis());
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::GatherND:
{
if (node.get_input_element_type(1) == element::i64)
{
reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0));
}
else if (node.get_input_element_type(1) == element::i32)
{
reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int32_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0));
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::Greater:
{
size_t element_count = shape_size(node.get_output_shape(0));
......
......@@ -102,3 +102,18 @@ embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_10x1_arbitrary_index_type_int64
floor_int32
gather_no_axis
gather
gather_nd_scalar_from_2d
gather_nd_1d_from_2d
gather_nd_scalar_from_3d
gather_nd_1d_from_3d
gather_nd_2d_from_3d
gather_nd_batch_scalar_from_2d
gather_nd_batch_1d_from_2d
gather_nd_batch_scalar_from_3d
gather_nd_batch_1d_from_3d
gather_nd_batch_2d_from_3d
gather_scalar_indices_no_axis
gather_scalar_indices
gather_nd_single_indices
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <numeric>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
// Implement gather by calling gather_nd on sub-problems
// # prepare constant shapes for tensors used for sub problems
// indices'.shape = indices.shape[-1] + [1]
// params'.shape = params.shape[axis:]
// out'.shape = params'.shape
// out'.shape[0] = indices.shape[-1]
// # call sub-problems
// foreach (params_index, out_index) in outer "axis" dimensions
// # params_prime is shared by inner loop
// params' = param[params_index] # rank(params') == rank(params) - axis
// foreach indices_index in outer N-1 dimensions
// indices' = indices[indices_index] # rank(indices') == 2
// out_index = out_index + indices_index
// out' = out[out_index] # rank(out') == rank(params')
// gather_nd(params', indices'', out')
template <typename T, typename U>
void gather(T* params,
U* indices,
T* out,
const Shape& params_shape,
const Shape& indices_shape,
const Shape& out_shape,
size_t axis)
{
using namespace std;
// prepare shape of params_prime (remove first "axis" dimensions)
Shape params_prime_shape(params_shape);
params_prime_shape.erase(params_prime_shape.begin(),
params_prime_shape.begin() + axis);
// prepare shape of indices_prime
size_t indices_ndim = static_cast<size_t>(indices_shape.size());
Shape indices_prime_shape;
// prepare shape of out_prime (same as params_prime except for first dim)
Shape out_prime_shape(params_prime_shape);
if (indices_ndim > 0)
{
out_prime_shape[0] = indices_shape[indices_ndim - 1];
indices_prime_shape.emplace_back(indices_shape[indices_ndim - 1]);
}
else
{
out_prime_shape[0] = 1;
}
indices_prime_shape.emplace_back(1);
// Create a CoordinateTransform for "out" that visits the outer "axis" dimensions
size_t out_ndim = static_cast<size_t>(out_shape.size());
Coordinate out_outer_start_corner(out_ndim, 0);
Coordinate out_outer_end_corner(out_shape);
for (size_t i = axis; i < out_ndim; i++)
{
out_outer_end_corner[i] = 1;
}
Strides out_outer_strides(out_ndim, 1);
AxisVector out_outer_axis_order(out_ndim);
std::iota(out_outer_axis_order.begin(), out_outer_axis_order.end(), 0);
CoordinateTransform out_outer_transform(out_shape,
out_outer_start_corner,
out_outer_end_corner,
out_outer_strides,
out_outer_axis_order);
// Create a CoordinateTransform for "params" that visits the outer "axis" dimensions
size_t params_ndim = static_cast<size_t>(params_shape.size());
Coordinate params_outer_start_corner(params_ndim, 0);
Coordinate params_outer_end_corner(params_shape);
for (size_t i = axis; i < params_ndim; i++)
{
params_outer_end_corner[i] = 1;
}
Strides params_outer_strides(params_ndim, 1);
AxisVector params_outer_axis_order(params_ndim);
std::iota(params_outer_axis_order.begin(), params_outer_axis_order.end(), 0);
CoordinateTransform params_outer_transform(params_shape,
params_outer_start_corner,
params_outer_end_corner,
params_outer_strides,
params_outer_axis_order);
// Create a CoordinateTransform for "indices" that visits only the first element along inner most axis
Coordinate indices_outer_start_corner(indices_ndim, 0);
Coordinate indices_outer_end_corner(indices_shape);
if (indices_ndim > 0)
{
indices_outer_end_corner[indices_ndim - 1] = 1;
}
Strides indices_outer_strides(indices_ndim, 1);
AxisVector indices_outer_axis_order(indices_ndim);
std::iota(indices_outer_axis_order.begin(), indices_outer_axis_order.end(), 0);
CoordinateTransform indices_outer_transform(indices_shape,
indices_outer_start_corner,
indices_outer_end_corner,
indices_outer_strides,
indices_outer_axis_order);
// Create an inner CoordinateTransfrom for "out"
size_t out_inner_ndim = out_ndim - axis;
Shape out_inner_shape(out_shape);
out_inner_shape.erase(out_inner_shape.begin(), out_inner_shape.begin() + axis);
Coordinate out_inner_start_corner(out_inner_ndim, 0);
Coordinate out_inner_end_corner(out_inner_shape);
if (indices_ndim > 0)
{
out_inner_end_corner[indices_ndim - 1] = 1;
}
for (size_t i = indices_ndim; i < out_inner_ndim; i++)
{
out_inner_end_corner[i] = 1;
}
Strides out_inner_strides(out_inner_ndim, 1);
AxisVector out_inner_axis_order(out_inner_ndim);
std::iota(out_inner_axis_order.begin(), out_inner_axis_order.end(), 0);
CoordinateTransform out_inner_transform(out_inner_shape,
out_inner_start_corner,
out_inner_end_corner,
out_inner_strides,
out_inner_axis_order);
auto out_outer_coord_iter = out_outer_transform.begin();
for (const Coordinate& params_outer_coord : params_outer_transform)
{
T* params_prime = &params[params_outer_transform.index(params_outer_coord)];
T* out_outer = &out[out_outer_transform.index(*out_outer_coord_iter)];
auto out_inner_coord_iter = out_inner_transform.begin();
for (const Coordinate& indices_outer_coord : indices_outer_transform)
{
U* indices_prime =
&indices[indices_outer_transform.index(indices_outer_coord)];
T* out_prime = &out_outer[out_inner_transform.index(*out_inner_coord_iter)];
gather_nd<T, U>(params_prime,
indices_prime,
out_prime,
params_prime_shape,
indices_prime_shape,
out_prime_shape);
out_inner_coord_iter++;
}
out_outer_coord_iter++;
}
}
}
}
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <numeric>
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
// foreach leaf_vector_index in indices.shape[:-1]
// vector = indices[leaf_vector_index]
// out[leaf_vector_index:] = params[vector]
            // Gathers slices of "params" addressed by index tuples stored along the
            // innermost axis of "indices", writing each slice contiguously into "out".
            //
            // params         - source data buffer (row-major)
            // indices        - index buffer; the innermost dimension holds one index tuple
            // out            - destination buffer (row-major)
            // params_shape   - shape of params
            // indices_shape  - shape of indices; rank is assumed >= 1 (enforced by
            //                  GatherND op validation — indices_ndim - 1 below would
            //                  underflow otherwise)
            // out_shape      - shape of out: indices_shape[:-1] + params_shape[slice_rank:]
            template <typename T, typename U>
            void gather_nd(const T* params,
                           const U* indices,
                           T* out,
                           const Shape& params_shape,
                           const Shape& indices_shape,
                           const Shape& out_shape)
            {
                using namespace std;
                // Create a CoordinateTransform for "indices" that visits only the first element along inner most axis
                size_t indices_ndim = static_cast<size_t>(indices_shape.size());
                Coordinate indices_outer_start_corner(indices_ndim, 0);
                Coordinate indices_outer_end_corner(indices_shape);
                // slice_rank = length of one index tuple = number of leading params
                // axes that each tuple addresses.
                size_t slice_rank = indices_shape[indices_ndim - 1];
                // Collapse the innermost axis to 1 so iteration yields one coordinate
                // per index tuple (the tuple's remaining elements are read manually below).
                indices_outer_end_corner[indices_ndim - 1] = 1;
                Strides indices_strides(indices_ndim, 1);
                AxisVector indices_axis_order(indices_ndim);
                std::iota(indices_axis_order.begin(), indices_axis_order.end(), 0);
                CoordinateTransform indices_outer_transform(indices_shape,
                                                            indices_outer_start_corner,
                                                            indices_outer_end_corner,
                                                            indices_strides,
                                                            indices_axis_order);
                // Create a matching CoordinateTransform for "out" that visits the same outer coordinates
                size_t out_ndim = static_cast<size_t>(out_shape.size());
                Coordinate out_start_corner(out_ndim, 0);
                Coordinate out_end_corner(out_shape);
                // Restrict "out" iteration to the outer (batch) axes; the trailing
                // slice axes are filled by the sequential copy loop below.
                for (size_t i = indices_ndim - 1; i < out_ndim; i++)
                {
                    out_end_corner[i] = 1;
                }
                Strides out_strides(out_ndim, 1);
                AxisVector out_axis_order(out_ndim);
                std::iota(out_axis_order.begin(), out_axis_order.end(), 0);
                CoordinateTransform out_transform(
                    out_shape, out_start_corner, out_end_corner, out_strides, out_axis_order);
                size_t params_ndim = static_cast<size_t>(params_shape.size());
                Strides params_strides(params_ndim, 1);
                AxisVector params_axis_order(params_ndim);
                std::iota(params_axis_order.begin(), params_axis_order.end(), 0);
                // Gather slices from "params" and copy to "out"
                auto out_coord_iter = out_transform.begin();
                for (const Coordinate& indices_coord : indices_outer_transform)
                {
                    // Narrow the params iteration window to the single slice selected
                    // by this index tuple: leading slice_rank axes are pinned to the
                    // tuple's values, remaining axes span their full extent.
                    Coordinate params_start_corner(params_ndim, 0);
                    Coordinate params_end_corner(params_shape);
                    auto indices_index = indices_outer_transform.index(indices_coord);
                    for (size_t i = 0; i < slice_rank; i++)
                    {
                        // Index tuple elements are contiguous along the innermost
                        // (unit-stride) axis of "indices", so a flat increment walks them.
                        U index = indices[indices_index];
                        params_start_corner[i] = index;
                        params_end_corner[i] = index + 1;
                        indices_index++;
                    }
                    CoordinateTransform params_transform(params_shape,
                                                         params_start_corner,
                                                         params_end_corner,
                                                         params_strides,
                                                         params_axis_order);
                    // Copy the selected slice element-by-element. NOTE(review): relies
                    // on each gathered slice occupying a contiguous run of "out"
                    // starting at out_index (row-major layout) — holds for the
                    // out_shape contract stated above.
                    auto out_index = out_transform.index(*out_coord_iter);
                    for (const Coordinate& params_coord : params_transform)
                    {
                        out[out_index] = params[params_transform.index(params_coord)];
                        out_index++;
                    }
                    out_coord_iter++;
                }
            }
}
}
}
......@@ -67,6 +67,8 @@
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
......@@ -883,6 +885,17 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Floor>(args[0]);
break;
}
case OP_TYPEID::Gather:
{
auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::Gather>(args[0], args[1], axis);
break;
}
case OP_TYPEID::GatherND:
{
node = make_shared<op::GatherND>(args[0], args[1]);
break;
}
case OP_TYPEID::GenerateMask:
{
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
......@@ -1679,6 +1692,14 @@ static json write(const Node& n, bool binary_constant_data)
}
case OP_TYPEID::Floor: { break;
}
case OP_TYPEID::Gather:
{
auto tmp = dynamic_cast<const op::Gather*>(&n);
node["axis"] = tmp->get_axis();
break;
}
case OP_TYPEID::GatherND: { break;
}
case OP_TYPEID::GetOutputElement:
{
auto tmp = dynamic_cast<const op::GetOutputElement*>(&n);
......
......@@ -143,6 +143,7 @@ set(MULTI_TEST_SRC
backend_dot.in.cpp
backend_embedding_lookup.in.cpp
backend_fused_op.in.cpp
backend_gather.in.cpp
backend_one_hot.in.cpp
backend_pool.in.cpp
backend_reshape.in.cpp
......
This diff is collapsed.
......@@ -13245,6 +13245,267 @@ TEST(type_prop, prelu)
ASSERT_EQ(prelu->get_shape(), prelu_shape);
}
TEST(type_prop, gather_no_axis)
{
    // Default axis (0): output shape is indices_shape followed by params_shape[1:].
    Shape params_shape{3, 2};
    Shape indices_shape{2, 2};
    Shape expected_shape{2, 2, 2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::Gather>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather)
{
    // Explicit axis 1: output shape is params_shape[:1] + indices_shape + params_shape[2:].
    Shape params_shape{3, 3};
    Shape indices_shape{1, 2};
    Shape expected_shape{3, 1, 2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::Gather>(params, indices, 1);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_scalar_from_2d)
{
    // Full-rank index tuples (length 2) into a 2-D params: each gathers a scalar.
    Shape params_shape{2, 2};
    Shape indices_shape{2, 2};
    Shape expected_shape{2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::GatherND>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_1d_from_2d)
{
    // Length-1 index tuples into a 2-D params: each gathers a 1-D row.
    Shape params_shape{2, 2};
    Shape indices_shape{2, 1};
    Shape expected_shape{2, 2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::GatherND>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_scalar_from_3d)
{
    // Full-rank index tuples (length 3) into a 3-D params: each gathers a scalar.
    Shape params_shape{2, 2, 2};
    Shape indices_shape{2, 3};
    Shape expected_shape{2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::GatherND>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_1d_from_3d)
{
    // Length-2 index tuples into a 3-D params: each gathers a 1-D slice.
    Shape params_shape{2, 2, 2};
    Shape indices_shape{2, 2};
    Shape expected_shape{2, 2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::GatherND>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_2d_from_3d)
{
    // Length-1 index tuples into a 3-D params: each gathers a 2-D slice.
    Shape params_shape{2, 2, 2};
    Shape indices_shape{1, 1};
    Shape expected_shape{1, 2, 2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::GatherND>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_scalar_from_2d)
{
    // Batched (3-D) indices with full-rank tuples into a 2-D params: scalars per batch.
    Shape params_shape{2, 2};
    Shape indices_shape{2, 1, 2};
    Shape expected_shape{2, 1};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::GatherND>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_1d_from_2d)
{
    // Batched (3-D) indices with length-1 tuples into a 2-D params: rows per batch.
    Shape params_shape{2, 2};
    Shape indices_shape{2, 1, 1};
    Shape expected_shape{2, 1, 2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::GatherND>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_scalar_from_3d)
{
    // Batched (3-D) indices with full-rank tuples into a 3-D params: scalars per batch.
    Shape params_shape{2, 2, 2};
    Shape indices_shape{2, 2, 3};
    Shape expected_shape{2, 2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::GatherND>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_1d_from_3d)
{
    // Batched (3-D) indices with length-2 tuples into a 3-D params: 1-D slices per batch.
    Shape params_shape{2, 2, 2};
    Shape indices_shape{2, 2, 2};
    Shape expected_shape{2, 2, 2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::GatherND>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_2d_from_3d)
{
    // Batched (3-D) indices with length-1 tuples into a 3-D params: 2-D slices per batch.
    Shape params_shape{2, 2, 2};
    Shape indices_shape{2, 1, 1};
    Shape expected_shape{2, 1, 2, 2};
    auto params = make_shared<op::Parameter>(element::f32, params_shape);
    auto indices = make_shared<op::Parameter>(element::i32, indices_shape);
    auto gather = make_shared<op::GatherND>(params, indices);
    ASSERT_EQ(gather->get_element_type(), element::f32);
    ASSERT_EQ(gather->get_shape(), expected_shape);
}
TEST(type_prop, gather_fail_params_rank)
{
    // Gather along axis 2 of a rank-2 params tensor must be rejected:
    // validation requires params rank >= axis + 1.
    // (Removed unused local `out_shape` — the op never constructs successfully here.)
    Shape params_shape{3, 3};
    Shape indices_shape{1, 2};
    auto P = make_shared<op::Parameter>(element::f32, params_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    try
    {
        auto G = make_shared<op::Gather>(P, I, 2);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect params rank";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("params rank is expected to be at least axis + 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, gather_fail_indices_element_type)
{
    // Gather indices must be i32 or i64; i16 indices must be rejected.
    // (Removed unused local `out_shape` — the op never constructs successfully here.)
    Shape params_shape{3, 3};
    Shape indices_shape{1, 2};
    auto P = make_shared<op::Parameter>(element::f32, params_shape);
    auto I = make_shared<op::Parameter>(element::i16, indices_shape);
    try
    {
        auto G = make_shared<op::Gather>(P, I, 1);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect indices element type";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Indices element type must be i64 or i32"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, gather_nd_fail_params_rank)
{
    // GatherND on a scalar (rank-0) params tensor must be rejected:
    // validation requires params rank >= 1.
    // (Removed unused local `out_shape` — the op never constructs successfully here.)
    Shape params_shape{};
    Shape indices_shape{2, 1, 1};
    auto P = make_shared<op::Parameter>(element::f32, params_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    try
    {
        auto G = make_shared<op::GatherND>(P, I);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect params rank";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("params rank is expected to be at least 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, gather_nd_fail_indices_rank)
{
    // GatherND with scalar (rank-0) indices must be rejected:
    // validation requires indices rank >= 1.
    // (Removed unused local `out_shape` — the op never constructs successfully here.)
    Shape params_shape{2, 2, 2};
    Shape indices_shape{};
    auto P = make_shared<op::Parameter>(element::f32, params_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    try
    {
        auto G = make_shared<op::GatherND>(P, I);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect indices rank";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("indices rank is expected to be at least 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, gather_nd_fail_indices_element_type)
{
    // GatherND indices must be i32 or i64; i16 indices must be rejected.
    // (Removed unused local `out_shape` — the op never constructs successfully here.)
    Shape params_shape{2, 2, 2};
    Shape indices_shape{2, 1, 1};
    auto P = make_shared<op::Parameter>(element::f32, params_shape);
    auto I = make_shared<op::Parameter>(element::i16, indices_shape);
    try
    {
        auto G = make_shared<op::GatherND>(P, I);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect indices element type";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Indices element type must be i64 or i32"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, conv_bias_2d_deduce)
{
// Deduce type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment