Commit 64412792 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

"Convert" operator (#156)

* Type propagation for convert

* Implement "convert" through VM

* Temporary commit to unconfuse git merge

* Undo temporary commit to unconfuse git merge
parent 1a44d7f8
...@@ -17,9 +17,10 @@ ...@@ -17,9 +17,10 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
void Convert::propagate_types() const element::Type& Convert::propagate_element_types(const element::Type& arg_element_type) const
{ {
throw ngraph_error("NIY"); return m_element_type;
} }
...@@ -27,9 +27,9 @@ namespace ngraph ...@@ -27,9 +27,9 @@ namespace ngraph
{ {
} }
virtual const element::Type&
propagate_element_types(const element::Type& arg_element_type) const override;
virtual std::string description() const override { return "Convert"; } virtual std::string description() const override { return "Convert"; }
virtual void propagate_types() override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
}; };
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ETI, typename ETO>
class ConvertInstruction : public Instruction
{
public:
ConvertInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ETO>(call_frame, m_out) =
EigenArray1d<ETI>(call_frame, m_arg).template cast<typename ETO::type>();
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "ngraph/ops/broadcast.hpp" #include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp" #include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp" #include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/divide.hpp" #include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp" #include "ngraph/ops/dot.hpp"
#include "ngraph/ops/equal.hpp" #include "ngraph/ops/equal.hpp"
...@@ -59,6 +60,7 @@ ...@@ -59,6 +60,7 @@
#include "ngraph/runtime/eigen/concat_matrix.hpp" #include "ngraph/runtime/eigen/concat_matrix.hpp"
#include "ngraph/runtime/eigen/concat_vector.hpp" #include "ngraph/runtime/eigen/concat_vector.hpp"
#include "ngraph/runtime/eigen/constant.hpp" #include "ngraph/runtime/eigen/constant.hpp"
#include "ngraph/runtime/eigen/convert.hpp"
#include "ngraph/runtime/eigen/copy.hpp" #include "ngraph/runtime/eigen/copy.hpp"
#include "ngraph/runtime/eigen/divide.hpp" #include "ngraph/runtime/eigen/divide.hpp"
#include "ngraph/runtime/eigen/dot.hpp" #include "ngraph/runtime/eigen/dot.hpp"
...@@ -443,6 +445,62 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -443,6 +445,62 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
} }
}; };
REGISTER_TO_OP_MAP(op::Convert)
{
auto arg = n->get_arguments().at(0);
auto arg_tensor_type =
dynamic_pointer_cast<const TensorViewType>(arg->get_value_type());
assert(nullptr != arg_tensor_type);
auto& arg_element_type = arg_tensor_type->get_element_type();
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto& result_element_type = result_tensor_type->get_element_type();
// Hacky macro: we are going to be building up a series of else-ifs for each possible
// pair of element types.
#define REGISTER_CONVERT(TI, TO) \
else if (arg_element_type == (TI::element_type()) && \
result_element_type == (TO::element_type())) \
{ \
ef->get_instructions()->push_back( \
make_shared<runtime::eigen::ConvertInstruction<TI, TO>>(in[0], out[0])); \
}
// End hacky macro
// Hacky macro: Given some type TI, generate the else-ifs for TI to every other element
// type.
#define REGISTER_CONVERTS(TI) \
REGISTER_CONVERT(TI, element::Bool) \
REGISTER_CONVERT(TI, element::Float32) \
REGISTER_CONVERT(TI, element::Int8) \
REGISTER_CONVERT(TI, element::Int32) \
REGISTER_CONVERT(TI, element::Int64) \
REGISTER_CONVERT(TI, element::UInt8) \
REGISTER_CONVERT(TI, element::UInt32) \
REGISTER_CONVERT(TI, element::UInt64)
// End hacky macro
if (false)
{
}
REGISTER_CONVERTS(element::Bool)
REGISTER_CONVERTS(element::Float32)
REGISTER_CONVERTS(element::Int8)
REGISTER_CONVERTS(element::Int32)
REGISTER_CONVERTS(element::Int64)
REGISTER_CONVERTS(element::UInt8)
REGISTER_CONVERTS(element::UInt32)
REGISTER_CONVERTS(element::UInt64)
else { throw ngraph_error("Internal error: cannot convert between element types"); }
#undef REGISTER_CONVERTS
#undef REGISTER_CONVERT
};
REGISTER_TO_OP_MAP(op::Dot) REGISTER_TO_OP_MAP(op::Dot)
{ {
auto& arg_nodes = n->get_arguments(); auto& arg_nodes = n->get_arguments();
......
...@@ -1150,3 +1150,63 @@ TEST(execute, test_broadcast_vector_rowwise_int64) ...@@ -1150,3 +1150,63 @@ TEST(execute, test_broadcast_vector_rowwise_int64)
ASSERT_EQ((vector<element::Int64::type>{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}), ASSERT_EQ((vector<element::Int64::type>{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}),
result->get_vector()); result->get_vector());
} }
TEST(execute, test_convert_int32_float32)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>(
make_shared<op::Convert>(A, element::Float32::element_type()), rt, op::Parameters{A});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int32>(shape);
*a = vector<element::Int32::type>{1, 2, 3, 4};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape);
(*cf)({a}, {result});
ASSERT_EQ((vector<element::Float32::type>{1, 2, 3, 4}), result->get_vector());
}
TEST(execute, test_convert_int32_bool)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Int32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Bool::element_type(), shape);
auto f = make_shared<Function>(
make_shared<op::Convert>(A, element::Bool::element_type()), rt, op::Parameters{A});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Int32>(shape);
*a = vector<element::Int32::type>{1, 2, 3, 4};
auto result = ngraph::runtime::make_tensor<element::Bool>(shape);
(*cf)({a}, {result});
ASSERT_EQ((vector<element::Bool::type>{1, 2, 3, 4}), result->get_vector());
}
TEST(execute, test_convert_float32_bool)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Bool::element_type(), shape);
auto f = make_shared<Function>(
make_shared<op::Convert>(A, element::Bool::element_type()), rt, op::Parameters{A});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape);
*a = vector<element::Float32::type>{1, 2, 3, 4};
auto result = ngraph::runtime::make_tensor<element::Bool>(shape);
(*cf)({a}, {result});
ASSERT_EQ((vector<element::Bool::type>{1, 2, 3, 4}), result->get_vector());
}
...@@ -237,9 +237,49 @@ TEST(type_prop, concat_deduce_elem_type_mismatch) ...@@ -237,9 +237,49 @@ TEST(type_prop, concat_deduce_elem_type_mismatch)
} }
} }
// TEST(type_prop, convert_deduce)
// Tests for dot product. {
// // Deduce type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->propagate_types();
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_correct)
{
// Check deduced type against incorrectly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->set_value_type(make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 3, 4}));
c->propagate_types();
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_incorrect)
{
// Check deduced type against incorrectly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->set_value_type(make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 14, 4}));
try
{
c->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dot_deduce_scalar_2d) TEST(type_prop, dot_deduce_scalar_2d)
{ {
// Deduce type for scalar/matrix arguments // Deduce type for scalar/matrix arguments
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment