Commit 16d88a7f authored by Nick Korovaiko, committed by Scott Cyphers

Embedding fprop (#2053)

* embedding fprop

* add a new line

* type prop tests

* rename

* add a stub handler for embeddinglookup on intelgpu

* rename embedding.* to embedding_lookup

* rename tests in manifest files

* move embeddinglookup to catchall case

* fix test case breaks after merge

* add a negative test, pull up an assertion

* fix test failures
parent af2c4c7d
......@@ -58,6 +58,7 @@ set (SRC
op/dequantize.cpp
op/divide.cpp
op/dot.cpp
op/embedding_lookup.cpp
op/equal.cpp
op/exp.cpp
op/experimental/generate_mask.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/embedding_lookup.hpp"
using namespace std;
using namespace ngraph;
void op::EmbeddingLookup::validate_and_infer_types()
{
// The result element type follows the weights (arg 1), not the indices.
element::Type result_et = get_input_element_type(1);
const PartialShape& arg0_shape = get_input_partial_shape(0);
const PartialShape& arg1_shape = get_input_partial_shape(1);
NODE_VALIDATION_ASSERT(
this, arg1_shape.rank().is_dynamic() || static_cast<size_t>(arg1_shape.rank()) == 2)
<< "weights are expected to be a matrix";
// The result shape is the indices (arg 0) shape with the embedding length appended.
PartialShape result_shape;
if (arg0_shape.rank().is_static())
{
std::vector<Dimension> result_dims(static_cast<size_t>(arg0_shape.rank()) + 1);
for (size_t i = 0; i < static_cast<size_t>(arg0_shape.rank()); i++)
{
result_dims[i] = arg0_shape[i];
}
result_dims[result_dims.size() - 1] =
arg1_shape.rank().is_static() ? arg1_shape[1] : Dimension::dynamic();
result_shape = PartialShape(result_dims);
}
else
{
result_shape = PartialShape::dynamic();
}
set_output_type(0, result_et, result_shape);
}
shared_ptr<Node> op::EmbeddingLookup::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<EmbeddingLookup>(new_args.at(0), new_args.at(1));
}
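The inference above sets the output element type to the weights type and builds the output shape by appending the embedding length to the indices shape. A quick illustration of the rule, as a sketch using shapes from the type_prop tests added in this commit (not code from the commit itself):

// Sketch only: output shape = indices shape with the embedding length appended.
auto data = make_shared<op::Parameter>(element::f32, Shape{8, 10, 12});
auto weights = make_shared<op::Parameter>(element::f32, Shape{5, 10});
auto embed = make_shared<op::EmbeddingLookup>(data, weights);
// embed->get_shape() == Shape{8, 10, 12, 10}; embed->get_element_type() == element::f32

// With dynamic weights, only the trailing dimension stays dynamic:
auto dyn_weights = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto embed_dyn = make_shared<op::EmbeddingLookup>(data, dyn_weights);
// embed_dyn->get_output_partial_shape(0) has scheme {8, 10, 12, Dimension::dynamic()}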
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/axis_set.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/index_reduction.hpp"
namespace ngraph
{
namespace op
{
/// \brief Returns embeddings for given indices
class EmbeddingLookup : public Op
{
public:
/// \brief Constructs an EmbeddingLookup operation.
///
/// EmbeddingLookup constructs an output tensor by replacing every index in the given
/// input tensor with the corresponding row of the weights matrix.
///
/// \param data The input indices for the tokens to be translated into embeddings
/// \param weights A dense matrix [N, M] where each of the N rows corresponds to an
/// embedding (i.e. typically a vector of real numbers) of length M
EmbeddingLookup(const std::shared_ptr<Node>& data, const std::shared_ptr<Node>& weights)
: Op("EmbeddingLookup", check_single_output_args({data, weights}))
{
constructor_validate_and_infer_types();
}
void validate_and_infer_types() override;
void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override
{
throw ngraph_error("Not yet implemented");
}
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
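To make the doxygen description above concrete, here is a short usage sketch; the values are borrowed from the embedding_lookup_4x5_reverse test added later in this commit, and the "INTERPRETER" backend name is an assumption:

// Indices {3, 2, 1, 0} select rows of a 4x5 weights matrix in reverse order.
auto A = make_shared<op::Parameter>(element::f32, Shape{4});    // indices
auto B = make_shared<op::Parameter>(element::f32, Shape{4, 5}); // weights [N, M]
auto f = make_shared<Function>(NodeVector{make_shared<op::EmbeddingLookup>(A, B)},
                               ParameterVector{A, B});
auto backend = runtime::Backend::create("INTERPRETER"); // assumed backend name
// Feeding A = {3, 2, 1, 0} and B rows [1..5], [6..10], [11..15], [16..20]
// produces the rows in reverse order: [16..20, 11..15, 6..10, 1..5].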
......@@ -127,3 +127,4 @@ NGRAPH_OP(Sum, ngraph::op)
NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(TopK, ngraph::op)
NGRAPH_OP(EmbeddingLookup, ngraph::op)
......@@ -44,6 +44,7 @@ set(SRC
builder/convert_layout.cpp
builder/convolution.cpp
builder/dot.cpp
builder/embedding_lookup.cpp
builder/function_call.cpp
builder/leaky_relu.cpp
builder/lstm.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstring>
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::EmbeddingLookup)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
CPUKernelFunctor functor;
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& arg1_tensor = tensor_data[args[1].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
if (out[0].get_element_type() != element::f32 &&
out[0].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported output element type");
}
auto in_shape = args[1].get_shape();
size_t element_count = shape_size(args[0].get_shape());
auto out_shape = out[0].get_shape();
auto element_type = out[0].get_element_type();
auto index_element_type = args[0].get_element_type();
if (element_type == element::f32)
{
if (index_element_type == element::f32)
{
functor = [&, in_shape, element_count](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<float, float>(
static_cast<float*>(arg0_tensor),
static_cast<float*>(arg1_tensor),
static_cast<float*>(out_tensor),
element_count,
in_shape);
};
}
else if (index_element_type == element::i32)
{
functor = [&, in_shape, element_count](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<float, int>(
static_cast<int*>(arg0_tensor),
static_cast<float*>(arg1_tensor),
static_cast<float*>(out_tensor),
element_count,
in_shape);
};
}
else
{
throw ngraph_error(
"Unsupported index type in CPU Builder for EmbeddingLookup");
}
}
else if (element_type == element::i32)
{
if (index_element_type == element::f32)
{
functor = [&, in_shape, element_count](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<int, float>(
static_cast<float*>(arg0_tensor),
static_cast<int*>(arg1_tensor),
static_cast<int*>(out_tensor),
element_count,
in_shape);
};
}
else if (index_element_type == element::i32)
{
functor = [&, in_shape, element_count](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<int, int>(
static_cast<int*>(arg0_tensor),
static_cast<int*>(arg1_tensor),
static_cast<int*>(out_tensor),
element_count,
in_shape);
};
}
else
{
throw ngraph_error(
"Unsupported index type in CPU Builder for EmbeddingLookup");
}
}
else
{
throw ngraph_error("Unsupported type in CPU Builder for ArgMin");
}
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(EmbeddingLookup);
}
}
}
......@@ -45,6 +45,7 @@
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
......@@ -2264,6 +2265,24 @@ namespace ngraph
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::EmbeddingLookup)
{
writer.block_begin();
const ngraph::op::EmbeddingLookup* embed =
static_cast<const ngraph::op::EmbeddingLookup*>(node);
auto index_type_name = embed->get_argument(0)->get_element_type().c_type_string();
auto type_name = embed->get_element_type().c_type_string();
auto element_count = shape_size(embed->get_argument(0)->get_shape());
writer << "reference::embedding<" << type_name << "," << index_type_name << ">(";
writer << " " << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " " << element_count << ",\n";
writer << " {" << join(args[1].get_shape()) << "});\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sin)
{
......
......@@ -61,6 +61,7 @@
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
......@@ -318,6 +319,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Sign), &runtime::cpu::CPU_Emitter::emit<op::Sign>},
{TI(ngraph::op::Slice), &runtime::cpu::CPU_Emitter::emit<op::Slice>},
{TI(ngraph::op::Sum), &runtime::cpu::CPU_Emitter::emit<op::Sum>},
{TI(ngraph::op::EmbeddingLookup), &runtime::cpu::CPU_Emitter::emit<op::EmbeddingLookup>},
{TI(ngraph::op::Exp), &runtime::cpu::CPU_Emitter::emit<op::Exp>},
{TI(ngraph::op::Sin), &runtime::cpu::CPU_Emitter::emit<op::Sin>},
{TI(ngraph::op::Sinh), &runtime::cpu::CPU_Emitter::emit<op::Sinh>},
......@@ -492,6 +494,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/runtime/reference/generate_mask.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/max.hpp"
......
......@@ -52,6 +52,7 @@
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
......@@ -631,6 +632,11 @@ void runtime::gpu::GPU_Emitter::emit_Dot(EMIT_ARGS)
writer.block_end();
}
void runtime::gpu::GPU_Emitter::emit_EmbeddingLookup(EMIT_ARGS)
{
throw ngraph_error("EmbeddingLookup is not yet implemented for NVIDIA GPU");
}
void runtime::gpu::GPU_Emitter::emit_Equal(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Equal>(external_function, writer, node, args, out);
......
......@@ -27,6 +27,9 @@ backwards_maxpool_n2_c1_hw5_3x3_str2_max
backwards_avgpool_n1_c1_hw2x2
backwards_avgpool_n1_c1_hw4x4
backwards_avgpool_n2_c2_hw4x4
embedding_lookup_4x5_reverse
embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
batch_norm_inference_0eps_f64
batch_norm_inference_0eps_f32
batch_norm_inference_f64
......
......@@ -73,6 +73,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
......@@ -1722,6 +1723,7 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function>
case OP_TYPEID::ShapeOf:
case OP_TYPEID::StopGradient:
case OP_TYPEID::TopK:
case OP_TYPEID::EmbeddingLookup:
{
throw unsupported_op("Unsupported op '" + op->description() +
"' in IntelGPU back end.");
......
......@@ -28,6 +28,9 @@ dequantize_int8
dequantize_int8_zero_offset
dequantize_zero_offset
divide_by_zero_int32
embedding_lookup_4x5_reverse
embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
function_call
generate_mask
max_pool_3d
......
......@@ -31,6 +31,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -82,6 +83,7 @@
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
......@@ -672,6 +674,51 @@ private:
dot->get_reduction_axes_count());
break;
}
case OP_TYPEID::EmbeddingLookup:
{
const op::EmbeddingLookup* embed = static_cast<const op::EmbeddingLookup*>(&node);
auto type = embed->get_argument(0)->get_element_type();
size_t element_count = shape_size(embed->get_argument(0)->get_shape());
if (type == element::f32)
{
reference::embedding<T, float>(static_cast<const float*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count,
embed->get_shape());
}
else if (type == element::f64)
{
reference::embedding<T, double>(static_cast<const double*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count,
embed->get_shape());
}
else if (type == element::i32)
{
reference::embedding<T, int>(static_cast<const int*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count,
embed->get_shape());
}
else if (type == element::i64)
{
reference::embedding<T, int64_t>(static_cast<const int64_t*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
element_count,
embed->get_shape());
}
else
{
throw ngraph_error(std::string("Unsupported index type ") + type.c_type_string() +
std::string("in EmbeddingLookup"));
}
break;
}
case OP_TYPEID::Equal:
{
size_t element_count = shape_size(node.get_output_shape(0));
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <cstring>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T, typename U>
void embedding(const U* indices,
const T* weights,
T* out,
size_t indices_count,
const Shape& out_shape)
{
// out_shape.at(1) supplies the embedding (row) length; callers pass a shape whose
// second dimension is the width M of the weights matrix.
size_t vec_len = out_shape.at(1);
T* out_iter = out;
for (size_t i = 0; i < indices_count; i++)
{
// Copy the row of the weights matrix selected by indices[i] into the output.
memcpy(out_iter,
&weights[vec_len * static_cast<size_t>(indices[i])],
sizeof(T) * vec_len);
out_iter += vec_len;
}
}
}
}
}
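A self-contained sketch of calling this kernel directly; the extra ngraph/shape.hpp include and the build setup are assumptions, and the values mirror the 4x5 reverse backend test:

#include <iostream>
#include <vector>
#include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/shape.hpp" // for ngraph::Shape (assumed include path)

int main()
{
    std::vector<float> indices{3, 2, 1, 0}; // float indices are cast to size_t inside the kernel
    std::vector<float> weights(20);
    for (size_t i = 0; i < weights.size(); i++)
    {
        weights[i] = static_cast<float>(i + 1); // rows: 1..5, 6..10, 11..15, 16..20
    }
    std::vector<float> out(20, 0.0f);
    // The shape argument only supplies the row length via .at(1); {4, 5} matches the weights.
    ngraph::runtime::reference::embedding<float, float>(
        indices.data(), weights.data(), out.data(), indices.size(), ngraph::Shape{4, 5});
    for (float v : out)
    {
        std::cout << v << " "; // prints 16..20 11..15 6..10 1..5
    }
    std::cout << std::endl;
    return 0;
}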
......@@ -42,6 +42,7 @@
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
......@@ -714,6 +715,11 @@ static shared_ptr<ngraph::Function>
}
break;
}
case OP_TYPEID::EmbeddingLookup:
{
node = make_shared<op::EmbeddingLookup>(args[0], args[1]);
break;
}
case OP_TYPEID::Equal:
{
node = make_shared<op::Equal>(args[0], args[1]);
......@@ -1382,6 +1388,8 @@ static json write(const Node& n, bool binary_constant_data)
node["reduction_axes_count"] = tmp->get_reduction_axes_count();
break;
}
case OP_TYPEID::EmbeddingLookup: { break;
}
case OP_TYPEID::Equal: { break;
}
case OP_TYPEID::Exp: { break;
......
......@@ -118,6 +118,7 @@ set(MULTI_TEST_SRC
backend_broadcast.in.cpp
backend_comparison.in.cpp
backend_dot.in.cpp
backend_embedding_lookup.in.cpp
backend_one_hot.in.cpp
backend_pool.in.cpp
backend_reduce.in.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, embedding_lookup_4x5_reverse)
{
Shape shape{4};
Shape rshape{4, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, rshape);
auto embed = make_shared<op::EmbeddingLookup>(A, B);
auto f0 = make_shared<Function>(NodeVector{embed}, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{3, 2, 1, 0});
auto b = backend->create_tensor(element::f32, rshape);
copy_data(b,
vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
auto result0 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(backend->compile(f0), {result0}, {a, b});
vector<float> expected{16, 17, 18, 19, 20, 11, 12, 13, 14, 15, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5};
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result0)));
}
NGRAPH_TEST(${BACKEND_NAME}, embedding_lookup_10x1_arbitrary)
{
Shape shape{10};
Shape rshape{10, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, rshape);
auto embed = make_shared<op::EmbeddingLookup>(A, B);
auto f0 = make_shared<Function>(NodeVector{embed}, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 1, 0, 3, 5, 4, 6, 8, 7});
auto b = backend->create_tensor(element::f32, rshape);
copy_data(b, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
auto result0 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(backend->compile(f0), {result0}, {a, b});
vector<float> expected{9, 2, 1, 0, 3, 5, 4, 6, 8, 7};
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result0)));
}
NGRAPH_TEST(${BACKEND_NAME}, embedding_lookup_10x1_arbitrary_index_type_int)
{
Shape shape{10};
Shape rshape{10, 1};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto B = make_shared<op::Parameter>(element::f32, rshape);
auto embed = make_shared<op::EmbeddingLookup>(A, B);
auto f0 = make_shared<Function>(NodeVector{embed}, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape);
copy_data(a, vector<int>{9, 2, 1, 0, 3, 5, 4, 6, 8, 7});
auto b = backend->create_tensor(element::f32, rshape);
copy_data(b, vector<float>{0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
auto result0 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(backend->compile(f0), {result0}, {a, b});
vector<float> expected{9.5, 2.5, 1.5, 0.5, 3.5, 5.5, 4.5, 6.5, 8.5, 7.5};
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result0)));
}
......@@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include <memory>
using namespace std;
......@@ -2385,6 +2386,64 @@ TEST(type_prop, or_bad_arguments)
});
}
TEST(type_prop, embedding_lookup_non_matrix_weights)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::boolean, Shape{2, 4, 5});
try
{
auto bc = make_shared<op::EmbeddingLookup>(tv0_2_4_param_0, tv0_2_4_param_1);
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("weights are expected to be a matrix"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, embedding_lookup_static_shapes)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{8, 10, 12});
auto weights = make_shared<op::Parameter>(element::f32, Shape{5, 10});
auto embed = make_shared<op::EmbeddingLookup>(data, weights);
ASSERT_EQ(embed->get_element_type(), element::f32);
ASSERT_EQ(embed->get_shape(), (Shape{8, 10, 12, 10}));
}
TEST(type_prop, embedding_lookup_dynamic_shape_arg0)
{
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto weights = make_shared<op::Parameter>(element::f32, Shape{5, 10});
auto embed = make_shared<op::EmbeddingLookup>(data, weights);
ASSERT_EQ(embed->get_element_type(), element::f32);
ASSERT_TRUE(embed->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, embedding_lookup_dynamic_shape_arg1)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{8, 10, 12});
auto weights = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto embed = make_shared<op::EmbeddingLookup>(data, weights);
ASSERT_EQ(embed->get_element_type(), element::f32);
PartialShape expected{8, 10, 12, Dimension::dynamic()};
ASSERT_TRUE(embed->get_output_partial_shape(0).same_scheme(expected));
}
TEST(type_prop, embedding_lookup_shape_arg1_dynamic_embedding_length)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{8, 10, 12});
auto weights = make_shared<op::Parameter>(element::f32, PartialShape{5, Dimension::dynamic()});
auto embed = make_shared<op::EmbeddingLookup>(data, weights);
ASSERT_EQ(embed->get_element_type(), element::f32);
PartialShape expected{8, 10, 12, Dimension::dynamic()};
ASSERT_TRUE(embed->get_output_partial_shape(0).same_scheme(expected));
}
TEST(type_prop, comparison_good)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
......