Commit 4ecdb791 authored by Jai Menon, committed by GitHub

Merge branch 'master' into jmenon/codegen

parents d185b48c 65aeb4b5
@@ -51,6 +51,11 @@ namespace ngraph
/// With non-linear buffers, this will need to be something other than size_t.
virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0;
const element::Type& get_element_type() const
{
return m_tensor_view.get_tensor_view_type()->get_element_type();
}
const Shape& get_shape() const
{
return m_tensor_view.get_tensor_view_type()->get_shape();
...
@@ -17,9 +17,10 @@
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
-void Convert::propagate_types()
-{
-    throw ngraph_error("NIY");
-}
const element::Type& Convert::propagate_element_types(const element::Type& arg_element_type) const
{
return m_element_type;
}
@@ -27,9 +27,9 @@ namespace ngraph
{
}
virtual const element::Type&
propagate_element_types(const element::Type& arg_element_type) const override;
virtual std::string description() const override { return "Convert"; }
-virtual void propagate_types() override;
protected:
const ngraph::element::Type& m_element_type;
};
...
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ETI, typename ETO>
class ConvertInstruction : public Instruction
{
public:
ConvertInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ETO>(call_frame, m_out) =
EigenArray1d<ETI>(call_frame, m_arg).template cast<typename ETO::type>();
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
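For intuition: the Eigen expression in execute() is just an element-wise cast over the flattened buffer. Here is a minimal stand-alone sketch of what ConvertInstruction<element::Int32, element::Float32> computes; convert_i32_to_f32 is a hypothetical helper, not part of this commit, and the EigenArray1d/CallFrame plumbing is elided:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Cast each element of the input buffer to the output element type.
    std::vector<float> convert_i32_to_f32(const std::vector<int32_t>& arg)
    {
        std::vector<float> out(arg.size());
        std::transform(arg.begin(), arg.end(), out.begin(),
                       [](int32_t x) { return static_cast<float>(x); });
        return out;
    }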
@@ -28,6 +28,7 @@
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/divide.hpp" #include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp" #include "ngraph/ops/dot.hpp"
#include "ngraph/ops/equal.hpp" #include "ngraph/ops/equal.hpp"
@@ -59,6 +60,7 @@
#include "ngraph/runtime/eigen/concat_matrix.hpp"
#include "ngraph/runtime/eigen/concat_vector.hpp"
#include "ngraph/runtime/eigen/constant.hpp"
#include "ngraph/runtime/eigen/convert.hpp"
#include "ngraph/runtime/eigen/copy.hpp" #include "ngraph/runtime/eigen/copy.hpp"
#include "ngraph/runtime/eigen/divide.hpp" #include "ngraph/runtime/eigen/divide.hpp"
#include "ngraph/runtime/eigen/dot.hpp" #include "ngraph/runtime/eigen/dot.hpp"
@@ -102,19 +104,214 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
const std::vector<TensorViewInfo>& in, \
const std::vector<TensorViewInfo>& out)
// Suppress Clang's complaints about the ,##__VA_ARGS__ token-pasting hack, which is a GNU extension
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
#define DO_ON_ELEMENT_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Bool::element_type()) \
{ \
macro(element::Bool, ##__VA_ARGS__); \
} \
else if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else if (et == element::UInt8::element_type()) \
{ \
macro(element::UInt8, ##__VA_ARGS__); \
} \
else if (et == element::UInt32::element_type()) \
{ \
macro(element::UInt32, ##__VA_ARGS__); \
} \
else if (et == element::UInt64::element_type()) \
{ \
macro(element::UInt64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
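In other words, DO_ON_ELEMENT_TYPE(et, err_msg, macro, ...) switches on a runtime element::Type and invokes macro(T, ...) with the matching static type T. A hypothetical caller, as a sketch only (PRINT_BACKING_TYPE is not in the commit):

    #include <iostream>
    #include <typeinfo>

    // Print the C++ type that backs a runtime element::Type.
    #define PRINT_BACKING_TYPE(T) std::cout << typeid(T::type).name() << "\n";

    void print_backing_type(const ngraph::element::Type& et)
    {
        // Expands to the eight-way if/else-if chain defined above.
        DO_ON_ELEMENT_TYPE(et, "unsupported element type", PRINT_BACKING_TYPE);
    }

    #undef PRINT_BACKING_TYPE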
#define DO_ON_NUMERIC_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else if (et == element::UInt8::element_type()) \
{ \
macro(element::UInt8, ##__VA_ARGS__); \
} \
else if (et == element::UInt32::element_type()) \
{ \
macro(element::UInt32, ##__VA_ARGS__); \
} \
else if (et == element::UInt64::element_type()) \
{ \
macro(element::UInt64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
#define DO_ON_SIGNED_NUMERIC_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
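Note the tiering of the three dispatchers: DO_ON_ELEMENT_TYPE covers all eight element types, DO_ON_NUMERIC_TYPE drops Bool, and DO_ON_SIGNED_NUMERIC_TYPE further drops the unsigned integer types, leaving Float32 and the signed integers.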
#define REGISTER_INSTRUCTION(op_class, instr_class, ...) \
REGISTER_TO_OP_MAP(op_class) \
{ \
ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \
}
-// Versions that include the descriptor
-#define REGISTER_UNOP(op_class, instr_class) \
-    REGISTER_INSTRUCTION(op_class, instr_class, in[0], out[0])
-#define REGISTER_BINOP(op_class, instr_class) \
-    REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], out[0])
-#define REGISTER_TERNOP(op_class, instr_class) \
-    REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], in[2], out[0])
#define M_REGISTER_SIGNED_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_SIGNED_NUMERIC_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_SIGNED_NUMERIC_TYPE( \
et, \
"Internal error: signed numeric unop has unhandled element type", \
M_REGISTER_SIGNED_NUMERIC_UNOP, \
instr_class); \
}
#define M_REGISTER_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_NUMERIC_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_NUMERIC_TYPE(et, \
"Internal error: numeric unop has unhandled element type", \
M_REGISTER_NUMERIC_UNOP, \
instr_class); \
}
#define M_REGISTER_NUMERIC_BINOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], in[1], out[0]));
#define REGISTER_NUMERIC_BINOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_NUMERIC_TYPE(et, \
"Internal error: numeric binop has unhandled element type", \
M_REGISTER_NUMERIC_BINOP, \
instr_class); \
}
#define M_REGISTER_POLYMORPHIC_BINOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], in[1], out[0]));
#define REGISTER_POLYMORPHIC_BINOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_ELEMENT_TYPE(et, \
"Internal error: polymorphic binop has unhandled element type", \
M_REGISTER_POLYMORPHIC_BINOP, \
instr_class); \
}
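Putting the two layers together: REGISTER_NUMERIC_BINOP(op::Add, runtime::eigen::AddInstruction) registers a handler that inspects the element type of argument 0 and, for Float32 inputs, effectively pushes (single-branch sketch):

    ef->get_instructions()->push_back(
        make_shared<runtime::eigen::AddInstruction<element::Float32>>(in[0], in[1], out[0]));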
// Something sneaky here: note the at(1) instead of at(0). For Select, argument 0 is the
// boolean selector, so the result's element type must be taken from argument 1.
#define M_REGISTER_POLYMORPHIC_TERNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], in[1], in[2], out[0]));
#define REGISTER_POLYMORPHIC_TERNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(1)->get_value_type())) \
->get_element_type(); \
DO_ON_ELEMENT_TYPE(et, \
"Internal error: polymorphic ternop has unhandled element type", \
M_REGISTER_POLYMORPHIC_TERNOP, \
instr_class); \
}
#define REGISTER_CONSTANT_INSTRUCTIONS(T) \
{ \
REGISTER_INSTRUCTION( \
op::ScalarConstant<T>, \
runtime::eigen::ConstantInstruction<T>, \
std::vector<T::type>{dynamic_cast<const op::ScalarConstant<T>*>(n)->get_value()}, \
out[0]); \
REGISTER_INSTRUCTION( \
op::TensorConstant<T>, \
runtime::eigen::ConstantInstruction<T>, \
std::vector<T::type>{ \
dynamic_cast<const op::TensorConstant<T>*>(n)->get_value()->get_vector()}, \
out[0]); \
}
#define PUSH_INSTRUCTION(T, instr, ...) \
{ \
ef->get_instructions()->push_back(make_shared<instr<T>>(__VA_ARGS__)); \
}
#define PUSH_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...) \
DO_ON_ELEMENT_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
#define PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...) \
DO_ON_NUMERIC_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
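PUSH_POLYMORPHIC_INSTRUCTION is the direct (non-registering) counterpart for use inside a hand-written handler. As a sketch, when et is Float32, a call such as PUSH_POLYMORPHIC_INSTRUCTION(et, "msg", runtime::eigen::CopyInstruction, in[0].get_index(), out[0].get_index()) reduces to:

    ef->get_instructions()->push_back(
        make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
            in[0].get_index(), out[0].get_index()));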
// Turn off complaint suppression (see above)
#pragma clang diagnostic pop
// Define code generators for handled ops.
ExternalFunction::OpMap& ExternalFunction::get_op_map()
@@ -123,34 +320,34 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
static OpMap op_map;
if (!initialized)
{
-REGISTER_UNOP(op::Abs, runtime::eigen::AbsInstruction<element::Float32>);
-REGISTER_BINOP(op::Add, runtime::eigen::AddInstruction<element::Float32>);
-REGISTER_BINOP(op::Divide, runtime::eigen::DivideInstruction<element::Float32>);
-REGISTER_BINOP(op::Equal, runtime::eigen::EqualInstruction<element::Float32>);
-REGISTER_BINOP(op::Greater, runtime::eigen::GreaterThanInstruction<element::Float32>);
-REGISTER_BINOP(op::GreaterEq, runtime::eigen::GreaterEqInstruction<element::Float32>);
-REGISTER_BINOP(op::Less, runtime::eigen::LessThanInstruction<element::Float32>);
-REGISTER_BINOP(op::LessEq, runtime::eigen::LessEqInstruction<element::Float32>);
-REGISTER_UNOP(op::Log, runtime::eigen::LogInstruction<element::Float32>);
-REGISTER_BINOP(op::Maximum, runtime::eigen::MaximumInstruction<element::Float32>);
-REGISTER_BINOP(op::Multiply, runtime::eigen::MultiplyInstruction<element::Float32>);
-REGISTER_UNOP(op::Negative, runtime::eigen::NegateInstruction<element::Float32>);
-REGISTER_BINOP(op::NotEqual, runtime::eigen::NotEqualInstruction<element::Float32>);
-REGISTER_TERNOP(op::Select, runtime::eigen::SelectInstruction<element::Float32>);
-REGISTER_BINOP(op::Subtract, runtime::eigen::SubtractInstruction<element::Float32>);
-REGISTER_INSTRUCTION(
-    op::ScalarConstant<element::Float32>,
-    runtime::eigen::ConstantInstruction<element::Float32>,
-    std::vector<element::Float32::type>{
-        dynamic_cast<const op::ScalarConstant<element::Float32>*>(n)->get_value()},
-    out[0]);
-REGISTER_INSTRUCTION(
-    op::TensorConstant<element::Float32>,
-    runtime::eigen::ConstantInstruction<element::Float32>,
-    dynamic_cast<const op::TensorConstant<element::Float32>*>(n)->get_value()->get_vector(),
-    out[0]);
REGISTER_NUMERIC_UNOP(op::Log, runtime::eigen::LogInstruction);
REGISTER_NUMERIC_UNOP(op::Negative, runtime::eigen::NegateInstruction);
REGISTER_SIGNED_NUMERIC_UNOP(op::Abs, runtime::eigen::AbsInstruction);
REGISTER_NUMERIC_BINOP(op::Add, runtime::eigen::AddInstruction);
REGISTER_NUMERIC_BINOP(op::Divide, runtime::eigen::DivideInstruction);
REGISTER_NUMERIC_BINOP(op::Greater, runtime::eigen::GreaterThanInstruction);
REGISTER_NUMERIC_BINOP(op::GreaterEq, runtime::eigen::GreaterEqInstruction);
REGISTER_NUMERIC_BINOP(op::Less, runtime::eigen::LessThanInstruction);
REGISTER_NUMERIC_BINOP(op::LessEq, runtime::eigen::LessEqInstruction);
REGISTER_NUMERIC_BINOP(op::Maximum, runtime::eigen::MaximumInstruction);
REGISTER_NUMERIC_BINOP(op::Multiply, runtime::eigen::MultiplyInstruction);
REGISTER_NUMERIC_BINOP(op::Subtract, runtime::eigen::SubtractInstruction);
REGISTER_POLYMORPHIC_BINOP(op::Equal, runtime::eigen::EqualInstruction);
REGISTER_POLYMORPHIC_BINOP(op::NotEqual, runtime::eigen::NotEqualInstruction);
REGISTER_POLYMORPHIC_TERNOP(op::Select, runtime::eigen::SelectInstruction);
REGISTER_CONSTANT_INSTRUCTIONS(element::Bool);
REGISTER_CONSTANT_INSTRUCTIONS(element::Float32);
REGISTER_CONSTANT_INSTRUCTIONS(element::Int8);
REGISTER_CONSTANT_INSTRUCTIONS(element::Int32);
REGISTER_CONSTANT_INSTRUCTIONS(element::Int64);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt8);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt32);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt64);
REGISTER_TO_OP_MAP(op::Broadcast)
{
@@ -166,40 +363,46 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto arg_shape = arg_tensor_type->get_shape();
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
if (broadcast->get_broadcast_axes().empty())
{
-// Degenerate case: no broadcast axes is just a copy.
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
-        in[0].get_index(), out[0].get_index()));
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
runtime::eigen::CopyInstruction,
in[0].get_index(),
out[0].get_index());
}
else if (arg_shape.size() == 0)
{
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::BroadcastScalarInstruction<element::Float32>>(
-        in[0], out[0]));
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
runtime::eigen::BroadcastScalarInstruction,
in[0],
out[0]);
}
else if (arg_shape.size() == 1 && result_shape.size() == 2)
{
if (broadcast->get_broadcast_axes() == AxisSet{1})
{
-ef->get_instructions()->push_back(
-    make_shared<
-        runtime::eigen::BroadcastVectorColwiseInstruction<element::Float32>>(
-        in[0], out[0]));
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
runtime::eigen::BroadcastVectorColwiseInstruction,
in[0],
out[0]);
}
else if (broadcast->get_broadcast_axes() == AxisSet{0})
{
-ef->get_instructions()->push_back(
-    make_shared<
-        runtime::eigen::BroadcastVectorRowwiseInstruction<element::Float32>>(
-        in[0], out[0]));
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
runtime::eigen::BroadcastVectorRowwiseInstruction,
in[0],
out[0]);
}
else
{
throw ngraph_error(
-"Internal error: axis set for vector-matrix broadcast is neither {0} or "
"Internal error: axis set for vector-matrix broadcast is neither {0} nor "
"{1}");
}
}
@@ -216,20 +419,25 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
assert(nullptr != result_tensor_type);
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
if (result_shape.size() == 1)
{
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>(in,
-                                                                           out[0]));
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Concat has unhandled element type",
runtime::eigen::ConcatVectorInstruction,
in,
out[0]);
}
else if (result_shape.size() == 2)
{
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::ConcatMatrixInstruction<element::Float32>>(
-        in,
-        (dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
-        out[0]));
PUSH_POLYMORPHIC_INSTRUCTION(
result_element_type,
"Concat has unhandled element type",
runtime::eigen::ConcatMatrixInstruction,
in,
(dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
out[0]);
}
else
{
@@ -237,6 +445,62 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
}
};
REGISTER_TO_OP_MAP(op::Convert)
{
auto arg = n->get_arguments().at(0);
auto arg_tensor_type =
dynamic_pointer_cast<const TensorViewType>(arg->get_value_type());
assert(nullptr != arg_tensor_type);
auto& arg_element_type = arg_tensor_type->get_element_type();
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto& result_element_type = result_tensor_type->get_element_type();
// Hacky macro: we are going to be building up a series of else-ifs for each possible
// pair of element types.
#define REGISTER_CONVERT(TI, TO) \
else if (arg_element_type == (TI::element_type()) && \
result_element_type == (TO::element_type())) \
{ \
ef->get_instructions()->push_back( \
make_shared<runtime::eigen::ConvertInstruction<TI, TO>>(in[0], out[0])); \
}
// End hacky macro
// Hacky macro: Given some type TI, generate the else-ifs for TI to every other element
// type.
#define REGISTER_CONVERTS(TI) \
REGISTER_CONVERT(TI, element::Bool) \
REGISTER_CONVERT(TI, element::Float32) \
REGISTER_CONVERT(TI, element::Int8) \
REGISTER_CONVERT(TI, element::Int32) \
REGISTER_CONVERT(TI, element::Int64) \
REGISTER_CONVERT(TI, element::UInt8) \
REGISTER_CONVERT(TI, element::UInt32) \
REGISTER_CONVERT(TI, element::UInt64)
// End hacky macro
if (false)
{
}
REGISTER_CONVERTS(element::Bool)
REGISTER_CONVERTS(element::Float32)
REGISTER_CONVERTS(element::Int8)
REGISTER_CONVERTS(element::Int32)
REGISTER_CONVERTS(element::Int64)
REGISTER_CONVERTS(element::UInt8)
REGISTER_CONVERTS(element::UInt32)
REGISTER_CONVERTS(element::UInt64)
else { throw ngraph_error("Internal error: cannot convert between element types"); }
#undef REGISTER_CONVERTS
#undef REGISTER_CONVERT
};
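The if (false) {} above seeds the chain so that every REGISTER_CONVERT expansion can begin with else if, giving one flat chain over all 64 (input, output) element-type pairs. An abbreviated expansion, as a sketch (not part of the commit):

    if (false)
    {
    }
    else if (arg_element_type == element::Bool::element_type() &&
             result_element_type == element::Bool::element_type())
    {
        ef->get_instructions()->push_back(
            make_shared<runtime::eigen::ConvertInstruction<element::Bool, element::Bool>>(
                in[0], out[0]));
    }
    // ... 63 more pairs, one per remaining (TI, TO) combination ...
    else
    {
        throw ngraph_error("Internal error: cannot convert between element types");
    }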
REGISTER_TO_OP_MAP(op::Dot)
{
auto& arg_nodes = n->get_arguments();
@@ -253,44 +517,59 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto arg0_shape = arg0_tensor_type->get_shape();
auto arg1_shape = arg1_tensor_type->get_shape();
auto& arg0_element_type = arg0_tensor_type->get_element_type();
// If arg0 or arg1 is a scalar, emit a scalar-tensor product.
if (arg0_shape.size() == 0)
{
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::ScalarTensorProductInstruction<element::Float32>>(
-        in[0], in[1], out[0]));
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
runtime::eigen::ScalarTensorProductInstruction,
in[0],
in[1],
out[0]);
}
else if (arg1_shape.size() == 0)
{
-// If arg1 is the scalar, do the same thing but switch the order of operands.
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::ScalarTensorProductInstruction<element::Float32>>(
-        in[1], in[0], out[0]));
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
runtime::eigen::ScalarTensorProductInstruction,
in[1],
in[0],
out[0]);
}
// If arg0 and arg1 are both vectors, emit a dot product.
else if (arg0_shape.size() == 1 && arg1_shape.size() == 1)
{
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::DotInstruction<element::Float32>>(
-        in[0], in[1], out[0]));
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
runtime::eigen::DotInstruction,
in[0],
in[1],
out[0]);
}
// If arg0 is a matrix and arg1 is a vector, emit a matrix-vector product.
else if (arg0_shape.size() == 2 && arg1_shape.size() == 1)
{
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::MatrixVectorProductInstruction<element::Float32>>(
-        in[0], in[1], out[0]));
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
runtime::eigen::MatrixVectorProductInstruction,
in[0],
in[1],
out[0]);
}
// If arg0 and arg1 are both matrices, emit a matrix product.
else if (arg0_shape.size() == 2 && arg1_shape.size() == 2)
{
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::MatrixMultInstruction<element::Float32>>(
-        in[0], in[1], out[0]));
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
runtime::eigen::MatrixMultInstruction,
in[0],
in[1],
out[0]);
}
else
@@ -307,9 +586,17 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
auto get_tuple_element = static_cast<const op::GetTupleElement*>(n);
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
-        in.at(get_tuple_element->get_n()).get_index(), out.at(0).get_index()));
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto& result_element_type = result_tensor_type->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"GetTupleElement has unhandled element type",
runtime::eigen::CopyInstruction,
in.at(get_tuple_element->get_n()).get_index(),
out.at(0).get_index());
};
// Tuple will be spliced out, with the users of out connected to the corresponding in's source, but, for now, we need to copy. // Tuple will be spliced out, with the users of out connected to the corresponding in's source, but, for now, we need to copy.
@@ -317,9 +604,12 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
for (size_t i = 0; i < in.size(); ++i)
{
-ef->get_instructions()->push_back(
-    make_shared<runtime::eigen::CopyInstruction<element::Float32>>(
-        in.at(i).get_index(), out.at(i).get_index()));
auto& et = in.at(i).get_tensor_view_layout()->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(et,
"Tuple has unhandled element type",
runtime::eigen::CopyInstruction,
in.at(i).get_index(),
out.at(i).get_index());
}
};
@@ -467,8 +757,13 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame(Functio
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> temps;
for (auto tv : m_temp_views)
{
-temps.push_back(ngraph::runtime::make_tensor<ngraph::element::Float32>(
-    tv->get_tensor_view_type()->get_shape()));
auto& et = tv->get_tensor_view_type()->get_element_type();
auto shape = tv->get_tensor_view_type()->get_shape();
#define M(T) temps.push_back(ngraph::runtime::make_tensor<T>(shape));
DO_ON_ELEMENT_TYPE(
et, "Internal error: tried to create temporary for unhandled element type", M);
#undef M
}
return make_shared<ngraph::runtime::CallFrame>(
m_n_inputs, m_n_outputs, temps, 0, m_instructions);
...
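The locally scoped M macro gives DO_ON_ELEMENT_TYPE a one-argument hook, so each temporary is allocated with the element type of its tensor view. For a Float32 view the dispatch selects, in effect:

    temps.push_back(ngraph::runtime::make_tensor<ngraph::element::Float32>(shape));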
@@ -14,6 +14,7 @@
#include "gtest/gtest.h"
#include <cmath>
#include "ngraph/ngraph.hpp"
using namespace std;
@@ -50,6 +51,37 @@ TEST(execute, test_abc)
ASSERT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector());
}
TEST(execute, test_abc_int64)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Int64::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Int64::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Int64::element_type(), shape);
auto f = make_shared<Function>((A + B) * C, rt, op::Parameters{A, B, C});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int64>(shape);
*a = vector<element::Int64::type>{1, 2, 3, 4};
auto b = ngraph::runtime::make_tensor<element::Int64>(shape);
*b = vector<element::Int64::type>{5, 6, 7, 8};
auto c = ngraph::runtime::make_tensor<element::Int64>(shape);
*c = vector<element::Int64::type>{9, 10, 11, 12};
auto result = ngraph::runtime::make_tensor<element::Int64>(shape);
(*cf)({a, b, c}, {result});
ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());
(*cf)({b, a, c}, {result});
ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());
(*cf)({a, c, b}, {result});
ASSERT_EQ((vector<element::Int64::type>{50, 72, 98, 128}), result->get_vector());
}
// Same as test_abc, but using tuples for input and output
TEST(execute, test_abc_tuple)
{
@@ -92,6 +124,48 @@ TEST(execute, test_abc_tuple)
ASSERT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector());
}
// Same as test_abc_tuple, but with the Int64 element type
TEST(execute, test_abc_tuple_int64)
{
auto shape = Shape{2, 2};
auto tensor_view_type = make_shared<TensorViewType>(element::Int64::element_type(), shape);
auto ABC = make_shared<op::Parameter>(
make_shared<TupleType>(ValueTypes{tensor_view_type, tensor_view_type, tensor_view_type}));
auto A = make_shared<op::GetTupleElement>(ABC, 0);
auto B = make_shared<op::GetTupleElement>(ABC, 1);
auto C = make_shared<op::GetTupleElement>(ABC, 2);
auto f = make_shared<Function>(
make_shared<op::Tuple>(Nodes{(A + B) * C}), tensor_view_type, op::Parameters{ABC});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int64>(shape);
*a = vector<element::Int64::type>{1, 2, 3, 4};
auto b = ngraph::runtime::make_tensor<element::Int64>(shape);
*b = vector<element::Int64::type>{5, 6, 7, 8};
auto c = ngraph::runtime::make_tensor<element::Int64>(shape);
*c = vector<element::Int64::type>{9, 10, 11, 12};
auto abc = ngraph::runtime::make_tuple({a, b, c});
auto bac = ngraph::runtime::make_tuple({b, a, c});
auto acb = ngraph::runtime::make_tuple({a, c, b});
auto result = ngraph::runtime::make_tensor<element::Int64>(shape);
auto result_tuple = ngraph::runtime::make_tuple({result});
(*cf)({abc}, {result_tuple});
ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());
(*cf)({bac}, {result_tuple});
ASSERT_EQ((vector<element::Int64::type>{54, 80, 110, 144}), result->get_vector());
(*cf)({acb}, {result_tuple});
ASSERT_EQ((vector<element::Int64::type>{50, 72, 98, 128}), result->get_vector());
}
// Multiple retrieved values
TEST(execute, test_tuple_result)
{
@@ -206,6 +280,36 @@ TEST(execute, test_concat_matrix_rowwise)
result->get_vector());
}
TEST(execute, test_concat_matrix_int64)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
auto shape_b = Shape{3, 2};
auto B = make_shared<op::Parameter>(element::Int64::element_type(), shape_b);
auto shape_c = Shape{3, 2};
auto C = make_shared<op::Parameter>(element::Int64::element_type(), shape_c);
auto shape_r = Shape{8, 2};
auto rt = make_shared<TensorViewType>(element::Int64::element_type(), Shape{8, 2});
auto f = make_shared<Function>(
make_shared<op::Concat>(Nodes{A, B, C}, 0), rt, op::Parameters{A, B, C});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int64>(shape_a);
*a = vector<element::Int64::type>{2, 4, 8, 16};
auto b = ngraph::runtime::make_tensor<element::Int64>(shape_b);
*b = vector<element::Int64::type>{1, 2, 4, 8, 16, 32};
auto c = ngraph::runtime::make_tensor<element::Int64>(shape_c);
*c = vector<element::Int64::type>{2, 3, 5, 7, 11, 13};
auto result = ngraph::runtime::make_tensor<element::Int64>(shape_r);
(*cf)({a, b, c}, {result});
ASSERT_EQ((vector<element::Int64::type>{2, 4, 8, 16, 1, 2, 4, 8, 16, 32, 2, 3, 5, 7, 11, 13}),
result->get_vector());
}
TEST(execute, test_concat_vector)
{
auto shape_a = Shape{4};
@@ -560,6 +664,30 @@ TEST(execute, test_dot_matrix_vector)
ASSERT_EQ((vector<float>{190, 486, 782, 1078}), result->get_vector());
}
TEST(execute, test_dot_matrix_vector_int64)
{
auto shape_a = Shape{4, 4};
auto shape_b = Shape{4};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
auto B = make_shared<op::Parameter>(element::Int64::element_type(), shape_b);
auto rt = make_shared<TensorViewType>(element::Int64::element_type(), shape_b);
auto f = make_shared<Function>(make_shared<op::Dot>(A, B), rt, op::Parameters{A, B});
auto shape_r = Shape{4};
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int64>(shape_a);
*a = vector<element::Int64::type>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
auto b = ngraph::runtime::make_tensor<element::Int64>(shape_b);
*b = vector<element::Int64::type>{17, 18, 19, 20};
auto result = ngraph::runtime::make_tensor<element::Int64>(shape_r);
(*cf)({a, b}, {result});
ASSERT_EQ((vector<element::Int64::type>{190, 486, 782, 1078}), result->get_vector());
}
TEST(execute, test_greater)
{
auto shape = Shape{2, 2, 2};
@@ -1001,3 +1129,85 @@ TEST(execute, test_broadcast_vector_rowwise)
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}), result->get_vector());
}
TEST(execute, test_broadcast_vector_rowwise_int64)
{
auto shape_a = Shape{4};
auto A = make_shared<op::Parameter>(element::Int64::element_type(), shape_a);
auto shape_r = Shape{3, 4};
auto rt = make_shared<TensorViewType>(element::Int64::element_type(), shape_r);
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{0}), rt, op::Parameters{A});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int64>(shape_a);
*a = vector<element::Int64::type>{1, 2, 3, 4};
auto result = ngraph::runtime::make_tensor<element::Int64>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<element::Int64::type>{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}),
result->get_vector());
}
TEST(execute, test_convert_int32_float32)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>(
make_shared<op::Convert>(A, element::Float32::element_type()), rt, op::Parameters{A});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int32>(shape);
*a = vector<element::Int32::type>{1, 2, 3, 4};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape);
(*cf)({a}, {result});
ASSERT_EQ((vector<element::Float32::type>{1, 2, 3, 4}), result->get_vector());
}
TEST(execute, test_convert_int32_bool)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Bool::element_type(), shape);
auto f = make_shared<Function>(
make_shared<op::Convert>(A, element::Bool::element_type()), rt, op::Parameters{A});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int32>(shape);
*a = vector<element::Int32::type>{1, 2, 3, 4};
auto result = ngraph::runtime::make_tensor<element::Bool>(shape);
(*cf)({a}, {result});
ASSERT_EQ((vector<element::Bool::type>{1, 2, 3, 4}), result->get_vector());
}
TEST(execute, test_convert_float32_bool)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Bool::element_type(), shape);
auto f = make_shared<Function>(
make_shared<op::Convert>(A, element::Bool::element_type()), rt, op::Parameters{A});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape);
*a = vector<element::Float32::type>{1, 2, 3, 4};
auto result = ngraph::runtime::make_tensor<element::Bool>(shape);
(*cf)({a}, {result});
ASSERT_EQ((vector<element::Bool::type>{1, 2, 3, 4}), result->get_vector());
}
@@ -237,9 +237,49 @@ TEST(type_prop, concat_deduce_elem_type_mismatch)
}
}
-//
-// Tests for dot product.
-//
TEST(type_prop, convert_deduce)
{
// Deduce type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->propagate_types();
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_correct)
{
// Check deduced type against correctly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->set_value_type(make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 3, 4}));
c->propagate_types();
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_incorrect)
{
// Check deduced type against incorrectly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->set_value_type(make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 14, 4}));
try
{
c->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dot_deduce_scalar_2d)
{
// Deduce type for scalar/matrix arguments
...