Unverified commit 1eb9f9bf authored by Robert Kimball, committed by GitHub

Update to enable pass backend unit tests (#904)

* get all ops working

* enable autodiff tests for IE backend
parent a8a68452
@@ -17,6 +17,9 @@
#include "ngraph/runtime/ie/ie_backend.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
@@ -145,11 +148,23 @@ bool runtime::ie::IE_Backend::call(shared_ptr<Function> function,
}
// determine the element type that generate_calls should dispatch on
element::Type type;
if (dynamic_pointer_cast<op::util::BinaryElementwiseComparison>(op) ||
dynamic_pointer_cast<op::Select>(op))
{
// Get the type of the second input, not the first
// All BinaryElementwiseComparison ops have the same element type for both inputs
// Select has bool for first input and the type we are interested in for the second
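// (e.g. Equal compares two i32 tensors but produces an element::boolean
// output, so the output type cannot be used for dispatch here)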
type = op->get_inputs().at(1).get_tensor().get_element_type();
}
else if (dynamic_pointer_cast<op::Convert>(op))
{
type = op->get_inputs().at(0).get_tensor().get_element_type();
}
else
{
type = op->get_element_type();
}
generate_calls(type, *op, op_outputs, op_inputs);
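The branch added above decides which element type `generate_calls` is templated on. A minimal self-contained sketch of the same rule, using hypothetical stand-in types (the real code dispatches on the ngraph op classes included at the top of the file):

```cpp
#include <memory>

// Illustrative stand-ins only, not the real ngraph::Node hierarchy.
struct Op { virtual ~Op() = default; };
struct Comparison : Op {}; // (T, T) -> boolean, e.g. Equal, Less
struct Select : Op {};     // (boolean, T, T) -> T
struct Convert : Op {};    // T -> T2

enum class ElementType { boolean, f32, i32 };

// Mirror of the rule above: comparisons and Select must be templated on an
// input type, Convert on its source type, everything else on the output type.
ElementType dispatch_type(const std::shared_ptr<Op>& op,
                          ElementType output_type,
                          ElementType input0_type,
                          ElementType input1_type)
{
    if (std::dynamic_pointer_cast<Comparison>(op) || std::dynamic_pointer_cast<Select>(op))
    {
        return input1_type; // the inputs carry the interesting type, not the output
    }
    if (std::dynamic_pointer_cast<Convert>(op))
    {
        return input0_type; // Convert is templated on its source type
    }
    return output_type; // elementwise ops: output type equals input type
}
```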
@@ -47,6 +47,7 @@
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
@@ -58,6 +59,7 @@
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/constant.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/copy.hpp"
#include "ngraph/runtime/reference/cos.hpp"
@@ -93,6 +95,8 @@
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/select_and_scatter.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sin.hpp"
#include "ngraph/runtime/reference/sinh.hpp"
@@ -114,12 +118,15 @@ namespace ngraph
{
namespace ie
{
class IE_Backend;
}
}
}
class ngraph::runtime::ie::IE_Backend : public Backend
{
public:
std::shared_ptr<TensorView>
create_tensor(const element::Type& type, const Shape& shape, void* memory_pointer) override;
std::shared_ptr<TensorView> create_tensor(const element::Type& type,
const Shape& shape) override;
@@ -130,7 +137,7 @@
const std::vector<std::shared_ptr<TensorView>>& outputs,
const std::vector<std::shared_ptr<TensorView>>& inputs) override;
private:
static bool init;
void generate_calls(const element::Type& type,
Node& op,
@@ -207,8 +214,7 @@
else if (node_op == "AvgPoolBackprop")
{
op::AvgPoolBackprop* apb = dynamic_cast<op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
out[0]->get_shape(),
@@ -259,6 +265,83 @@
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Convert")
{
element::Type type = node.get_element_type();
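// Note: T is the input element type chosen by generate_calls; 'type' is
// the Convert op's output type, so this chain selects the output-side cast.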
if (type == element::boolean)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::f32)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<float*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::f64)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<double*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::i8)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<int8_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::i16)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<int16_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::i32)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<int32_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::i64)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<int64_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::u8)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<uint8_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::u16)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<uint16_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::u32)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<uint32_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::u64)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<uint64_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else
{
std::stringstream ss;
ss << "unsupported element type " << type << " op Convert";
throw std::runtime_error(ss.str());
}
}
else if (node_op == "Convolution")
{
auto c = static_cast<const op::Convolution*>(&node);
@@ -460,11 +543,9 @@
}
else if (node_op == "MaxPoolBackprop")
{
op::MaxPoolBackprop* max_pool_backprop = dynamic_cast<op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[1]->get_shape(),
@@ -569,16 +650,11 @@
op::Reduce* reduce = dynamic_cast<op::Reduce*>(&node);
std::shared_ptr<Function> reduction_function = reduce->get_functions()[0];
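// Bridge the scalar reduction Function into a std::function: write x and y
// into zero-rank host tensors, re-enter call(), then read back the scalar.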
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
auto tx = std::make_shared<HostTensorView>(
node.get_inputs().at(0).get_element_type(), Shape{}, "reduce_temp_x");
auto ty = std::make_shared<HostTensorView>(
node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_temp_y");
auto tr = std::make_shared<HostTensorView>(
node.get_output_element_type(0), Shape{}, "reduce_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
@@ -598,19 +674,13 @@
else if (node_op == "ReduceWindow")
{
op::ReduceWindow* reduce_window = dynamic_cast<op::ReduceWindow*>(&node);
std::shared_ptr<Function> reduction_function = reduce_window->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
auto tx = std::make_shared<HostTensorView>(
node.get_inputs().at(0).get_element_type(), Shape{}, "reduce_window_temp_x");
auto ty = std::make_shared<HostTensorView>(
node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_window_temp_y");
auto tr = std::make_shared<HostTensorView>(
node.get_output_element_type(0), Shape{}, "reduce_window_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
@@ -678,6 +748,62 @@
out[0]->get_shape(),
reverse->get_reversed_axes());
}
else if (node_op == "Select")
{
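// args[0] is the boolean selector (held as char); pick args[1] where it is
// true and args[2] where it is false, elementwise.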
reference::select<T>(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[2]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "SelectAndScatter")
{
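// The op carries two nested Functions: a boolean selection function that
// picks the winning element in each window, and a scatter function that
// combines the source value with the destination. Both are wrapped as
// std::function via scalar host tensors, as with Reduce above.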
ngraph::op::SelectAndScatter* select_and_scatter =
dynamic_cast<ngraph::op::SelectAndScatter*>(&node);
std::shared_ptr<ngraph::Function> selection_function =
select_and_scatter->get_functions()[0];
std::function<bool(T, T)> f_selection = [this, &node, selection_function](T x,
T y) -> bool {
auto tx = std::make_shared<runtime::HostTensorView>(
node.get_inputs().at(0).get_element_type(), Shape{}, "selection_temp_x");
auto ty = std::make_shared<runtime::HostTensorView>(
node.get_inputs().at(1).get_element_type(), Shape{}, "selection_temp_y");
auto tr = std::make_shared<runtime::HostTensorView>(
element::boolean, Shape{}, "selection_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
call(selection_function, {tr}, {tx, ty});
return *(reinterpret_cast<char*>(tr->get_data_ptr()));
};
std::shared_ptr<ngraph::Function> scatter_function =
select_and_scatter->get_functions()[1];
std::function<T(T, T)> f_scatter = [this, &node, scatter_function](T x, T y) -> T {
auto tx = std::make_shared<runtime::HostTensorView>(
node.get_inputs().at(0).get_element_type(), Shape{}, "scatter_temp_x");
auto ty = std::make_shared<runtime::HostTensorView>(
node.get_inputs().at(1).get_element_type(), Shape{}, "scatter_temp_y");
auto tr = std::make_shared<runtime::HostTensorView>(
node.get_output_element_type(0), Shape{}, "scatter_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
call(scatter_function, {tr}, {tx, ty});
return *(reinterpret_cast<T*>(tr->get_data_ptr()));
};
reference::select_and_scatter<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[2]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
f_selection,
f_scatter,
select_and_scatter->get_window_shape(),
select_and_scatter->get_window_movement_strides());
}
else if (node_op == "Sign")
{
reference::sign<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
@@ -756,7 +882,4 @@
throw ngraph_error(ss.str());
}
}
};
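The `reference::convert` and `reference::select` kernels included above are not part of this diff. Judging from the call sites, they are plain elementwise loops along these lines (a sketch, not the verbatim `ngraph/runtime/reference` sources):

```cpp
#include <cstddef>

namespace reference_sketch
{
    // convert: cast each input element to the output element type; the
    // output type is deduced from the out pointer at the call site.
    template <typename TI, typename TO>
    void convert(const TI* arg, TO* out, size_t count)
    {
        for (size_t i = 0; i < count; i++)
        {
            out[i] = static_cast<TO>(arg[i]);
        }
    }

    // select: elementwise ternary; the selector is one byte per element.
    template <typename T>
    void select(const char* arg0, const T* arg1, const T* arg2, T* out, size_t count)
    {
        for (size_t i = 0; i < count; i++)
        {
            out[i] = arg0[i] ? arg1[i] : arg2[i];
        }
    }
}
```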
@@ -17,6 +17,7 @@
#pragma once
#include <cstddef>
#include <iostream>
namespace ngraph
{
@@ -829,7 +829,6 @@ TEST(${BACKEND_NAME}, backwards_log)
TEST(${BACKEND_NAME}, backwards_maximum)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}"); // no convert support
auto backend = runtime::Backend::create("${BACKEND_NAME}");
test::Uniform<float> rng(-1.0f, 1.0f);
@@ -848,7 +847,6 @@ TEST(${BACKEND_NAME}, backwards_maximum)
TEST(${BACKEND_NAME}, backwards_minimum)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}"); // no convert support
auto backend = runtime::Backend::create("${BACKEND_NAME}");
test::Uniform<float> rng(-1.0f, 1.0f);
@@ -1019,7 +1017,6 @@ TEST(${BACKEND_NAME}, backwards_reshape)
TEST(${BACKEND_NAME}, backwards_select)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
@@ -1049,7 +1046,6 @@ TEST(${BACKEND_NAME}, backwards_select)
TEST(${BACKEND_NAME}, backwards_select_nested)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
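The `SKIP_TEST_FOR("IE", ...)` lines in the hunks above are what this commit deletes, so these autodiff tests now run on the IE backend. The macro itself is not shown in the diff; its presumed shape is an early return when the test runs against the named backend (an illustrative guess, not the actual test-harness definition):

```cpp
#include <string>

// Hypothetical definition for illustration only; the real macro lives in the
// test harness and may also log that the test was skipped.
#define SKIP_TEST_FOR(backend_to_skip, current_backend)                                            \
    if (std::string(backend_to_skip) == std::string(current_backend))                              \
    {                                                                                              \
        return; /* valid inside a void gtest TEST body */                                          \
    }
```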
@@ -1349,7 +1349,6 @@ TEST(${BACKEND_NAME}, notequal)
TEST(${BACKEND_NAME}, select)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
@@ -1767,7 +1766,6 @@ TEST(${BACKEND_NAME}, broadcast_matrix_2)
TEST(${BACKEND_NAME}, convert_int32_float32)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f =
@@ -1786,7 +1784,6 @@ TEST(${BACKEND_NAME}, convert_int32_float32)
TEST(${BACKEND_NAME}, convert_int32_bool)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::Convert>(A, element::boolean),
@@ -1805,7 +1802,6 @@ TEST(${BACKEND_NAME}, convert_int32_bool)
TEST(${BACKEND_NAME}, convert_float32_bool)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Convert>(A, element::boolean),
@@ -5148,7 +5144,6 @@ TEST(${BACKEND_NAME}, reduce_window_emulating_max_pool_2d_1channel_1image_stride
//
TEST(${BACKEND_NAME}, select_and_scatter_with_overlap)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_sel_a{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
@@ -5203,7 +5198,6 @@ TEST(${BACKEND_NAME}, select_and_scatter_with_overlap)
//
TEST(${BACKEND_NAME}, select_and_scatter_without_overlap)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_sel_a{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
@@ -5258,7 +5252,6 @@ TEST(${BACKEND_NAME}, select_and_scatter_without_overlap)
//
TEST(${BACKEND_NAME}, select_and_scatter_3d_without_overlap)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_sel_a{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
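With the IE skips removed, the `select` test earlier in this file exercises the new kernel end to end. A condensed sketch of what such a test builds, using only graph and backend calls visible in this diff; `op::ParameterVector` and the tensor-filling helpers are assumptions about the test API of the time, not shown here:

```cpp
// Sketch: build cond ? b : c as a graph and run it on the IE backend.
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Select>(A, B, C),
                               op::ParameterVector{A, B, C}); // assumed ctor form

auto backend = runtime::Backend::create("IE");
auto a = backend->create_tensor(element::boolean, shape);
auto b = backend->create_tensor(element::f32, shape);
auto c = backend->create_tensor(element::f32, shape);
auto result = backend->create_tensor(element::f32, shape);
// ... fill a, b, c with the test helpers, then:
backend->call(f, {result}, {a, b, c});
// result[i] == (a[i] ? b[i] : c[i]) for every element
```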