Commit 1a44d7f8 authored by Adam Procter, committed by Scott Cyphers

Add support for element-type-polymorphic operators (#157)

* Add support for some polymorphic operators (arithmetic, etc.)

* Extend to all other ops except Tuple, which is tricky (not quite sure where to get the needed type info)

* Some tidying up

* Extend to handle Tuple op

* Slightly more descriptive macro name
parent 7d11d579
@@ -51,6 +51,11 @@ namespace ngraph
/// With non-linear buffers, this will need to be something other than size_t.
virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0;
const element::Type& get_element_type() const
{
return m_tensor_view.get_tensor_view_type()->get_element_type();
}
const Shape& get_shape() const
{
return m_tensor_view.get_tensor_view_type()->get_shape();
......
@@ -102,19 +102,214 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
const std::vector<TensorViewInfo>& in, \
const std::vector<TensorViewInfo>& out)
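// The opening of REGISTER_TO_OP_MAP sits above this hunk; presumably it installs a handler
// with this signature (the node n, the ExternalFunction ef, and the input/output tensor
// views) into the op map, keyed by the op class.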
// Suppress Clang's complaints about the ,##__VA_ARGS__ token-pasting hack, which is a GNU extension
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
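// (Without the extension, invoking one of these dispatch macros with no variadic arguments
// would leave a dangling comma in the expansion, e.g. macro(element::Bool, ); the
// ,##__VA_ARGS__ pasting swallows that comma.)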
#define DO_ON_ELEMENT_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Bool::element_type()) \
{ \
macro(element::Bool, ##__VA_ARGS__); \
} \
else if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else if (et == element::UInt8::element_type()) \
{ \
macro(element::UInt8, ##__VA_ARGS__); \
} \
else if (et == element::UInt32::element_type()) \
{ \
macro(element::UInt32, ##__VA_ARGS__); \
} \
else if (et == element::UInt64::element_type()) \
{ \
macro(element::UInt64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
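The macro above is the crux of the change: an if/else chain on the runtime element type expands a caller-supplied macro with the matching static type, so one registration site can instantiate a templated instruction class for every supported type. A minimal standalone sketch of the same dispatch trick, compilable on its own (Tag, print_sizeof, DO_ON_TAG, and CALL_PRINT are invented for the example and are not ngraph code):

#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class Tag { Float32, Int64 };

template <typename T>
void print_sizeof(const char* label)
{
    std::cout << label << ": sizeof = " << sizeof(T) << '\n';
}

// Runtime tag -> static type dispatch; extra arguments ride along via ##__VA_ARGS__.
#define DO_ON_TAG(tag, err_msg, macro, ...)                                    \
    {                                                                          \
        if ((tag) == Tag::Float32)                                             \
        {                                                                      \
            macro(float, ##__VA_ARGS__);                                       \
        }                                                                      \
        else if ((tag) == Tag::Int64)                                          \
        {                                                                      \
            macro(int64_t, ##__VA_ARGS__);                                     \
        }                                                                      \
        else                                                                   \
        {                                                                      \
            throw std::runtime_error(err_msg);                                 \
        }                                                                      \
    }

#define CALL_PRINT(T, label) print_sizeof<T>(label);

int main()
{
    // Selects the Int64 branch at runtime; prints "dispatched: sizeof = 8".
    DO_ON_TAG(Tag::Int64, "unhandled tag", CALL_PRINT, "dispatched");
}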
#define DO_ON_NUMERIC_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else if (et == element::UInt8::element_type()) \
{ \
macro(element::UInt8, ##__VA_ARGS__); \
} \
else if (et == element::UInt32::element_type()) \
{ \
macro(element::UInt32, ##__VA_ARGS__); \
} \
else if (et == element::UInt64::element_type()) \
{ \
macro(element::UInt64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
#define DO_ON_SIGNED_NUMERIC_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
#define REGISTER_INSTRUCTION(op_class, instr_class, ...) \
REGISTER_TO_OP_MAP(op_class) \
{ \
ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \
}
// Versions that include the descriptor
#define REGISTER_UNOP(op_class, instr_class) \
REGISTER_INSTRUCTION(op_class, instr_class, in[0], out[0])
#define REGISTER_BINOP(op_class, instr_class) \
REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], out[0])
#define REGISTER_TERNOP(op_class, instr_class) \
REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], in[2], out[0])
#define M_REGISTER_SIGNED_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_SIGNED_NUMERIC_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_SIGNED_NUMERIC_TYPE( \
et, \
"Internal error: signed numeric unop has unhandled element type", \
M_REGISTER_SIGNED_NUMERIC_UNOP, \
instr_class); \
}
#define M_REGISTER_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_NUMERIC_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_NUMERIC_TYPE(et, \
"Internal error: numeric unop has unhandled element type", \
M_REGISTER_NUMERIC_UNOP, \
instr_class); \
}
#define M_REGISTER_NUMERIC_BINOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], in[1], out[0]));
#define REGISTER_NUMERIC_BINOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_NUMERIC_TYPE(et, \
"Internal error: numeric binop has unhandled element type", \
M_REGISTER_NUMERIC_BINOP, \
instr_class); \
}
#define M_REGISTER_POLYMORPHIC_BINOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], in[1], out[0]));
#define REGISTER_POLYMORPHIC_BINOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_ELEMENT_TYPE(et, \
"Internal error: polymorphic binop has unhandled element type", \
M_REGISTER_POLYMORPHIC_BINOP, \
instr_class); \
}
// Something sneaky here: note the at(1) instead of at(0). For Select, argument 0 is the
// boolean selector, so the element type of the value operands must come from argument 1.
#define M_REGISTER_POLYMORPHIC_TERNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], in[1], in[2], out[0]));
#define REGISTER_POLYMORPHIC_TERNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(1)->get_value_type())) \
->get_element_type(); \
DO_ON_ELEMENT_TYPE(et, \
"Internal error: polymorphic ternop has unhandled element type", \
M_REGISTER_POLYMORPHIC_TERNOP, \
instr_class); \
}
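// For reference, REGISTER_POLYMORPHIC_TERNOP(op::Select, runtime::eigen::SelectInstruction)
// expands to roughly the following (showing only the Float32 branch of DO_ON_ELEMENT_TYPE):
//
//     REGISTER_TO_OP_MAP(op::Select)
//     {
//         const element::Type& et = (dynamic_pointer_cast<const TensorViewType>(
//                                        n->get_arguments().at(1)->get_value_type()))
//                                       ->get_element_type();
//         if (et == element::Float32::element_type())
//         {
//             ef->get_instructions()->push_back(
//                 make_shared<runtime::eigen::SelectInstruction<element::Float32>>(
//                     in[0], in[1], in[2], out[0]));
//         }
//         // ... remaining element types, else throw ngraph_error(...)
//     }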
#define REGISTER_CONSTANT_INSTRUCTIONS(T) \
{ \
REGISTER_INSTRUCTION( \
op::ScalarConstant<T>, \
runtime::eigen::ConstantInstruction<T>, \
std::vector<T::type>{dynamic_cast<const op::ScalarConstant<T>*>(n)->get_value()}, \
out[0]); \
REGISTER_INSTRUCTION( \
op::TensorConstant<T>, \
runtime::eigen::ConstantInstruction<T>, \
std::vector<T::type>{ \
dynamic_cast<const op::TensorConstant<T>*>(n)->get_value()->get_vector()}, \
out[0]); \
}
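// Invoked once per element type below; registers both the scalar and tensor constant
// forms for T in one shot.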
#define PUSH_INSTRUCTION(T, instr, ...) \
{ \
ef->get_instructions()->push_back(make_shared<instr<T>>(__VA_ARGS__)); \
}
#define PUSH_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...) \
DO_ON_ELEMENT_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
#define PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...) \
DO_ON_NUMERIC_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
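// For example, PUSH_POLYMORPHIC_INSTRUCTION(et, msg, runtime::eigen::CopyInstruction, src, dst)
// tests et against each supported element type, and the matching branch expands through
// PUSH_INSTRUCTION to, e.g.,
//     ef->get_instructions()->push_back(
//         make_shared<runtime::eigen::CopyInstruction<element::Float32>>(src, dst));
// letting the op handlers below stay element-type-agnostic. (src and dst here are
// placeholder argument names.)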
// Turn off complaint suppression (see above)
#pragma clang diagnostic pop
// Define code generators for handled ops.
ExternalFunction::OpMap& ExternalFunction::get_op_map()
@@ -123,34 +318,34 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
static OpMap op_map;
if (!initialized)
{
REGISTER_UNOP(op::Abs, runtime::eigen::AbsInstruction<element::Float32>);
REGISTER_BINOP(op::Add, runtime::eigen::AddInstruction<element::Float32>);
REGISTER_BINOP(op::Divide, runtime::eigen::DivideInstruction<element::Float32>);
REGISTER_BINOP(op::Equal, runtime::eigen::EqualInstruction<element::Float32>);
REGISTER_BINOP(op::Greater, runtime::eigen::GreaterThanInstruction<element::Float32>);
REGISTER_BINOP(op::GreaterEq, runtime::eigen::GreaterEqInstruction<element::Float32>);
REGISTER_BINOP(op::Less, runtime::eigen::LessThanInstruction<element::Float32>);
REGISTER_BINOP(op::LessEq, runtime::eigen::LessEqInstruction<element::Float32>);
REGISTER_UNOP(op::Log, runtime::eigen::LogInstruction<element::Float32>);
REGISTER_BINOP(op::Maximum, runtime::eigen::MaximumInstruction<element::Float32>);
REGISTER_BINOP(op::Multiply, runtime::eigen::MultiplyInstruction<element::Float32>);
REGISTER_UNOP(op::Negative, runtime::eigen::NegateInstruction<element::Float32>);
REGISTER_BINOP(op::NotEqual, runtime::eigen::NotEqualInstruction<element::Float32>);
REGISTER_TERNOP(op::Select, runtime::eigen::SelectInstruction<element::Float32>);
REGISTER_BINOP(op::Subtract, runtime::eigen::SubtractInstruction<element::Float32>);
REGISTER_INSTRUCTION(
op::ScalarConstant<element::Float32>,
runtime::eigen::ConstantInstruction<element::Float32>,
std::vector<element::Float32::type>{
dynamic_cast<const op::ScalarConstant<element::Float32>*>(n)->get_value()},
out[0]);
REGISTER_INSTRUCTION(
op::TensorConstant<element::Float32>,
runtime::eigen::ConstantInstruction<element::Float32>,
dynamic_cast<const op::TensorConstant<element::Float32>*>(n)->get_value()->get_vector(),
out[0]);
REGISTER_NUMERIC_UNOP(op::Log, runtime::eigen::LogInstruction);
REGISTER_NUMERIC_UNOP(op::Negative, runtime::eigen::NegateInstruction);
REGISTER_SIGNED_NUMERIC_UNOP(op::Abs, runtime::eigen::AbsInstruction);
REGISTER_NUMERIC_BINOP(op::Add, runtime::eigen::AddInstruction);
REGISTER_NUMERIC_BINOP(op::Divide, runtime::eigen::DivideInstruction);
REGISTER_NUMERIC_BINOP(op::Greater, runtime::eigen::GreaterThanInstruction);
REGISTER_NUMERIC_BINOP(op::GreaterEq, runtime::eigen::GreaterEqInstruction);
REGISTER_NUMERIC_BINOP(op::Less, runtime::eigen::LessThanInstruction);
REGISTER_NUMERIC_BINOP(op::LessEq, runtime::eigen::LessEqInstruction);
REGISTER_NUMERIC_BINOP(op::Maximum, runtime::eigen::MaximumInstruction);
REGISTER_NUMERIC_BINOP(op::Multiply, runtime::eigen::MultiplyInstruction);
REGISTER_NUMERIC_BINOP(op::Subtract, runtime::eigen::SubtractInstruction);
REGISTER_POLYMORPHIC_BINOP(op::Equal, runtime::eigen::EqualInstruction);
REGISTER_POLYMORPHIC_BINOP(op::NotEqual, runtime::eigen::NotEqualInstruction);
REGISTER_POLYMORPHIC_TERNOP(op::Select, runtime::eigen::SelectInstruction);
REGISTER_CONSTANT_INSTRUCTIONS(element::Bool);
REGISTER_CONSTANT_INSTRUCTIONS(element::Float32);
REGISTER_CONSTANT_INSTRUCTIONS(element::Int8);
REGISTER_CONSTANT_INSTRUCTIONS(element::Int32);
REGISTER_CONSTANT_INSTRUCTIONS(element::Int64);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt8);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt32);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt64);
REGISTER_TO_OP_MAP(op::Broadcast)
{
@@ -166,40 +361,46 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto arg_shape = arg_tensor_type->get_shape();
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
if (broadcast->get_broadcast_axes().empty())
{
// Degenerate case: no broadcast axes is just a copy.
ef->get_instructions()->push_back(
make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
in[0].get_index(), out[0].get_index()));
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
runtime::eigen::CopyInstruction,
in[0].get_index(),
out[0].get_index());
}
else if (arg_shape.size() == 0)
{
ef->get_instructions()->push_back(
make_shared<runtime::eigen::BroadcastScalarInstruction<element::Float32>>(
in[0], out[0]));
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
runtime::eigen::BroadcastScalarInstruction,
in[0],
out[0]);
}
else if (arg_shape.size() == 1 && result_shape.size() == 2)
{
if (broadcast->get_broadcast_axes() == AxisSet{1})
{
ef->get_instructions()->push_back(
make_shared<
runtime::eigen::BroadcastVectorColwiseInstruction<element::Float32>>(
in[0], out[0]));
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
runtime::eigen::BroadcastVectorColwiseInstruction,
in[0],
out[0]);
}
else if (broadcast->get_broadcast_axes() == AxisSet{0})
{
ef->get_instructions()->push_back(
make_shared<
runtime::eigen::BroadcastVectorRowwiseInstruction<element::Float32>>(
in[0], out[0]));
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
runtime::eigen::BroadcastVectorRowwiseInstruction,
in[0],
out[0]);
}
else
{
throw ngraph_error(
"Internal error: axis set for vector-matrix broadcast is neither {0} or "
"Internal error: axis set for vector-matrix broadcast is neither {0} nor "
"{1}");
}
}
@@ -216,20 +417,25 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
assert(nullptr != result_tensor_type);
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
if (result_shape.size() == 1)
{
ef->get_instructions()->push_back(
make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>(in,
out[0]));
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Concat has unhandled element type",
runtime::eigen::ConcatVectorInstruction,
in,
out[0]);
}
else if (result_shape.size() == 2)
{
ef->get_instructions()->push_back(
make_shared<runtime::eigen::ConcatMatrixInstruction<element::Float32>>(
in,
(dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
out[0]));
PUSH_POLYMORPHIC_INSTRUCTION(
result_element_type,
"Concat has unhandled element type",
runtime::eigen::ConcatMatrixInstruction,
in,
(dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
out[0]);
}
else
{
@@ -253,44 +459,59 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto arg0_shape = arg0_tensor_type->get_shape();
auto arg1_shape = arg1_tensor_type->get_shape();
auto& arg0_element_type = arg0_tensor_type->get_element_type();
// If arg0 or arg1 is a scalar, emit a scalar-tensor product.
if (arg0_shape.size() == 0)
{
ef->get_instructions()->push_back(
make_shared<runtime::eigen::ScalarTensorProductInstruction<element::Float32>>(
in[0], in[1], out[0]));
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
runtime::eigen::ScalarTensorProductInstruction,
in[0],
in[1],
out[0]);
}
else if (arg1_shape.size() == 0)
{
// If arg1 is the scalar, do the same thing but switch the order of operands.
ef->get_instructions()->push_back(
make_shared<runtime::eigen::ScalarTensorProductInstruction<element::Float32>>(
in[1], in[0], out[0]));
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
runtime::eigen::ScalarTensorProductInstruction,
in[1],
in[0],
out[0]);
}
// If arg0 and arg1 are both vectors, emit a dot product.
else if (arg0_shape.size() == 1 && arg1_shape.size() == 1)
{
ef->get_instructions()->push_back(
make_shared<runtime::eigen::DotInstruction<element::Float32>>(
in[0], in[1], out[0]));
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
runtime::eigen::DotInstruction,
in[0],
in[1],
out[0]);
}
// If arg0 is a matrix and arg1 is a vector, emit a matrix-vector product.
else if (arg0_shape.size() == 2 && arg1_shape.size() == 1)
{
ef->get_instructions()->push_back(
make_shared<runtime::eigen::MatrixVectorProductInstruction<element::Float32>>(
in[0], in[1], out[0]));
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
runtime::eigen::MatrixVectorProductInstruction,
in[0],
in[1],
out[0]);
}
// If arg0 and arg1 are both matrices, emit a matrix product.
else if (arg0_shape.size() == 2 && arg1_shape.size() == 2)
{
ef->get_instructions()->push_back(
make_shared<runtime::eigen::MatrixMultInstruction<element::Float32>>(
in[0], in[1], out[0]));
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
runtime::eigen::MatrixMultInstruction,
in[0],
in[1],
out[0]);
}
else
@@ -307,9 +528,17 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
auto get_tuple_element = static_cast<const op::GetTupleElement*>(n);
ef->get_instructions()->push_back(
make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
in.at(get_tuple_element->get_n()).get_index(), out.at(0).get_index()));
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto& result_element_type = result_tensor_type->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"GetTupleElement has unhandled element type",
runtime::eigen::CopyInstruction,
in.at(get_tuple_element->get_n()).get_index(),
out.at(0).get_index());
};
// Tuple will eventually be spliced out, with each user of out connected to the corresponding in's source; for now, we need to copy.
@@ -317,9 +546,12 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
for (size_t i = 0; i < in.size(); ++i)
{
ef->get_instructions()->push_back(
make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
in.at(i).get_index(), out.at(i).get_index()));
auto& et = in.at(i).get_tensor_view_layout()->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(et,
"Tuple has unhandled element type",
runtime::eigen::CopyInstruction,
in.at(i).get_index(),
out.at(i).get_index());
}
};
@@ -467,8 +699,13 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame(Functio
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> temps;
for (auto tv : m_temp_views)
{
temps.push_back(ngraph::runtime::make_tensor<ngraph::element::Float32>(
tv->get_tensor_view_type()->get_shape()));
auto& et = tv->get_tensor_view_type()->get_element_type();
auto shape = tv->get_tensor_view_type()->get_shape();
#define M(T) temps.push_back(ngraph::runtime::make_tensor<T>(shape));
DO_ON_ELEMENT_TYPE(
et, "Internal error: tried to create temporary for unhandled element type", M);
#undef M
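// M is a single-use callback macro: DO_ON_ELEMENT_TYPE expands it once per supported
// type, so the selected branch becomes, e.g.,
//     temps.push_back(ngraph::runtime::make_tensor<element::Float32>(shape));
// The #undef keeps the short name M from leaking beyond this loop body.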
}
return make_shared<ngraph::runtime::CallFrame>(
m_n_inputs, m_n_outputs, temps, 0, m_instructions);
......
@@ -50,6 +50,37 @@ TEST(execute, test_abc)
ASSERT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector());
}
TEST(execute, test_abc_int64)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Int64::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Int64::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Int64::element_type(), shape);
auto f = make_shared<Function>((A + B) * C, rt, op::Parameters{A, B, C});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int64>(shape);
*a = vector<element::Int64::type>{1, 2, 3, 4};
auto b = ngraph::runtime::make_tensor<element::Int64>(shape);
*b = vector<element::Int64::type>{5, 6, 7, 8};
auto c = ngraph::runtime::make_tensor<element::Int64>(shape);
*c = vector<element::Int64::type>{9, 10, 11, 12};
auto result = ngraph::runtime::make_tensor<element::Int64>(shape);
(*cf)({a, b, c}, {result});
ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());
(*cf)({b, a, c}, {result});
ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());
(*cf)({a, c, b}, {result});
ASSERT_EQ((vector<element::Int64::type>{50, 72, 98, 128}), result->get_vector());
}
// Same as test_abc, but using tuples for input and output
TEST(execute, test_abc_tuple)
{
@@ -92,6 +123,48 @@ TEST(execute, test_abc_tuple)
ASSERT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector());
}
// Same as test_abc_int64, but using tuples for input and output
TEST(execute, test_abc_tuple_int64)
{
auto shape = Shape{2, 2};
auto tensor_view_type = make_shared<TensorViewType>(element::Int64::element_type(), shape);
auto ABC = make_shared<op::Parameter>(
make_shared<TupleType>(ValueTypes{tensor_view_type, tensor_view_type, tensor_view_type}));
auto A = make_shared<op::GetTupleElement>(ABC, 0);
auto B = make_shared<op::GetTupleElement>(ABC, 1);
auto C = make_shared<op::GetTupleElement>(ABC, 2);
auto f = make_shared<Function>(
make_shared<op::Tuple>(Nodes{(A + B) * C}), tensor_view_type, op::Parameters{ABC});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int64>(shape);
*a = vector<element::Int64::type>{1, 2, 3, 4};
auto b = ngraph::runtime::make_tensor<element::Int64>(shape);
*b = vector<element::Int64::type>{5, 6, 7, 8};
auto c = ngraph::runtime::make_tensor<element::Int64>(shape);
*c = vector<element::Int64::type>{9, 10, 11, 12};
auto abc = ngraph::runtime::make_tuple({a, b, c});
auto bac = ngraph::runtime::make_tuple({b, a, c});
auto acb = ngraph::runtime::make_tuple({a, c, b});
auto result = ngraph::runtime::make_tensor<element::Int64>(shape);
auto result_tuple = ngraph::runtime::make_tuple({result});
(*cf)({abc}, {result_tuple});
ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());
(*cf)({bac}, {result_tuple});
ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());
(*cf)({acb}, {result_tuple});
ASSERT_EQ((vector<element::Int64::type>{50, 72, 98, 128}), result->get_vector());
}
// Multiple retrieved values
TEST(execute, test_tuple_result)
{
@@ -206,6 +279,36 @@ TEST(execute, test_concat_matrix_rowwise)
result->get_vector());
}
TEST(execute, test_concat_matrix_int64)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
auto shape_b = Shape{3, 2};
auto B = make_shared<op::Parameter>(element::Int64::element_type(), shape_b);
auto shape_c = Shape{3, 2};
auto C = make_shared<op::Parameter>(element::Int64::element_type(), shape_c);
auto shape_r = Shape{8, 2};
auto rt = make_shared<TensorViewType>(element::Int64::element_type(), Shape{8, 2});
auto f = make_shared<Function>(
make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int64>(shape_a);
*a = vector<element::Int64::type>{2, 4, 8, 16};
auto b = ngraph::runtime::make_tensor<element::Int64>(shape_b);
*b = vector<element::Int64::type>{1, 2, 4, 8, 16, 32};
auto c = ngraph::runtime::make_tensor<element::Int64>(shape_c);
*c = vector<element::Int64::type>{2, 3, 5, 7, 11, 13};
auto result = ngraph::runtime::make_tensor<element::Int64>(shape_r);
(*cf)({a, b, c}, {result});
ASSERT_EQ((vector<element::Int64::type>{2, 4, 8, 16, 1, 2, 4, 8, 16, 32, 2, 3, 5, 7, 11, 13}),
result->get_vector());
}
TEST(execute, test_concat_vector)
{
auto shape_a = Shape{4};
@@ -560,6 +663,30 @@ TEST(execute, test_dot_matrix_vector)
ASSERT_EQ((vector<float>{190, 486, 782, 1078}), result->get_vector());
}
TEST(execute, test_dot_matrix_vector_int64)
{
auto shape_a = Shape{4, 4};
auto shape_b = Shape{4};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
auto B = make_shared<op::Parameter>(element::Int64::element_type(), shape_b);
auto rt = make_shared<TensorViewType>(element::Int64::element_type(), shape_b);
auto f = make_shared<Function>(make_shared<op::Dot>(A, B), rt, op::Parameters{A, B});
auto shape_r = Shape{4};
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int64>(shape_a);
*a = vector<element::Int64::type>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
auto b = ngraph::runtime::make_tensor<element::Int64>(shape_b);
*b = vector<element::Int64::type>{17, 18, 19, 20};
auto result = ngraph::runtime::make_tensor<element::Int64>(shape_r);
(*cf)({a, b}, {result});
ASSERT_EQ((vector<element::Int64::type>{190, 486, 782, 1078}), result->get_vector());
}
TEST(execute, test_greater)
{
auto shape = Shape{2, 2, 2};
@@ -1001,3 +1128,25 @@ TEST(execute, test_broadcast_vector_rowwise)
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}), result->get_vector());
}
TEST(execute, test_broadcast_vector_rowwise_int64)
{
auto shape_a = Shape{4};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
auto shape_r = Shape{3, 4};
auto rt = make_shared<TensorViewType>(element::Int64::element_type(), shape_r);
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{0}), rt, op::Parameters{A});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int64>(shape_a);
*a = vector<element::Int64::type>{1, 2, 3, 4};
auto result = ngraph::runtime::make_tensor<element::Int64>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<element::Int64::type>{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}),
result->get_vector());
}