Unverified Commit 1eb9f9bf authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Update to enable pass backend unit tests (#904)

* get all ops working

* enable autodiff tests for IE backend
parent a8a68452
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include "ngraph/runtime/ie/ie_backend.hpp" #include "ngraph/runtime/ie/ie_backend.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp" #include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/assign_layout.hpp" #include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
...@@ -145,11 +148,23 @@ bool runtime::ie::IE_Backend::call(shared_ptr<Function> function, ...@@ -145,11 +148,23 @@ bool runtime::ie::IE_Backend::call(shared_ptr<Function> function,
} }
// get op type // get op type
element::Type type = op->get_element_type(); element::Type type;
if (!op->get_inputs().empty()) if (dynamic_pointer_cast<op::util::BinaryElementwiseComparison>(op) ||
dynamic_pointer_cast<op::Select>(op))
{
// Get the type of the second input, not the first
// All BinaryElementwiseComparision ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second
type = op->get_inputs().at(1).get_tensor().get_element_type();
}
else if (dynamic_pointer_cast<op::Convert>(op))
{ {
type = op->get_inputs().at(0).get_tensor().get_element_type(); type = op->get_inputs().at(0).get_tensor().get_element_type();
} }
else
{
type = op->get_element_type();
}
generate_calls(type, *op, op_outputs, op_inputs); generate_calls(type, *op, op_outputs, op_inputs);
......
...@@ -47,6 +47,7 @@ ...@@ -47,6 +47,7 @@
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/runtime/reference/abs.hpp" #include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp" #include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp" #include "ngraph/runtime/reference/add.hpp"
...@@ -58,6 +59,7 @@ ...@@ -58,6 +59,7 @@
#include "ngraph/runtime/reference/ceiling.hpp" #include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp" #include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/constant.hpp" #include "ngraph/runtime/reference/constant.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/convolution.hpp" #include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/copy.hpp" #include "ngraph/runtime/reference/copy.hpp"
#include "ngraph/runtime/reference/cos.hpp" #include "ngraph/runtime/reference/cos.hpp"
...@@ -93,6 +95,8 @@ ...@@ -93,6 +95,8 @@
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp" #include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp" #include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/select_and_scatter.hpp"
#include "ngraph/runtime/reference/sign.hpp" #include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sin.hpp" #include "ngraph/runtime/reference/sin.hpp"
#include "ngraph/runtime/reference/sinh.hpp" #include "ngraph/runtime/reference/sinh.hpp"
...@@ -114,649 +118,768 @@ namespace ngraph ...@@ -114,649 +118,768 @@ namespace ngraph
{ {
namespace ie namespace ie
{ {
class IE_Backend : public Backend class IE_Backend;
{ }
public: }
std::shared_ptr<TensorView> create_tensor(const element::Type& type, }
const Shape& shape, class ngraph::runtime::ie::IE_Backend : public Backend
void* memory_pointer) override; {
public:
std::shared_ptr<TensorView>
create_tensor(const element::Type& type, const Shape& shape, void* memory_pointer) override;
std::shared_ptr<TensorView> create_tensor(const element::Type& type, std::shared_ptr<TensorView> create_tensor(const element::Type& type,
const Shape& shape) override; const Shape& shape) override;
bool compile(std::shared_ptr<Function> function) override; bool compile(std::shared_ptr<Function> function) override;
bool call(std::shared_ptr<Function> function, bool call(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<TensorView>>& outputs, const std::vector<std::shared_ptr<TensorView>>& outputs,
const std::vector<std::shared_ptr<TensorView>>& intputs) override; const std::vector<std::shared_ptr<TensorView>>& intputs) override;
private: private:
static bool init; static bool init;
void generate_calls(const element::Type& type, void generate_calls(const element::Type& type,
Node& op, Node& op,
const std::vector<std::shared_ptr<HostTensorView>>& outputs, const std::vector<std::shared_ptr<HostTensorView>>& outputs,
const std::vector<std::shared_ptr<HostTensorView>>& inputs); const std::vector<std::shared_ptr<HostTensorView>>& inputs);
template <typename T> template <typename T>
void op_engine(Node& node, void op_engine(Node& node,
const std::vector<std::shared_ptr<HostTensorView>>& out, const std::vector<std::shared_ptr<HostTensorView>>& out,
const std::vector<std::shared_ptr<HostTensorView>>& args) const std::vector<std::shared_ptr<HostTensorView>>& args)
{ {
std::string node_op = node.description(); std::string node_op = node.description();
if (node_op == "Abs") if (node_op == "Abs")
{ {
reference::abs<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::abs<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Acos") else if (node_op == "Acos")
{ {
reference::acos<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::acos<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Add") else if (node_op == "Add")
{ {
reference::add<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::add<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()), reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); out[0]->get_element_count());
} }
#ifdef NGRAPH_DISTRIBUTED #ifdef NGRAPH_DISTRIBUTED
else if (node_op == "AllReduce") else if (node_op == "AllReduce")
{ {
reference::allreduce<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::allreduce<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_element_type(), args[0]->get_element_type(),
static_cast<int>(args[0]->get_element_count())); static_cast<int>(args[0]->get_element_count()));
} }
#endif #endif
else if (node_op == "And") else if (node_op == "And")
{ {
reference::logical_and(reinterpret_cast<char*>(args[0]->get_data_ptr()), reference::logical_and(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(args[1]->get_data_ptr()), reinterpret_cast<char*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()), reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Asin") else if (node_op == "Asin")
{ {
reference::asin<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::asin<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Atan") else if (node_op == "Atan")
{ {
reference::atan<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::atan<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "AvgPool") else if (node_op == "AvgPool")
{ {
op::AvgPool* avg_pool = dynamic_cast<op::AvgPool*>(&node); op::AvgPool* avg_pool = dynamic_cast<op::AvgPool*>(&node);
reference::avg_pool<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
out[0]->get_shape(),
avg_pool->get_window_shape(),
avg_pool->get_window_movement_strides(),
avg_pool->get_padding_below(),
avg_pool->get_padding_above(),
avg_pool->get_include_padding_in_avg_computation());
}
else if (node_op == "AvgPoolBackprop")
{
op::AvgPoolBackprop* apb = dynamic_cast<op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
out[0]->get_shape(),
apb->get_window_shape(),
apb->get_window_movement_strides(),
apb->get_padding_below(),
apb->get_padding_above(),
apb->get_include_padding_in_avg_computation());
}
else if (node_op == "Broadcast")
{
op::Broadcast* broadcast = dynamic_cast<op::Broadcast*>(&node);
Shape in_shape = args[0]->get_shape();
Shape out_shape = out[0]->get_shape();
AxisSet broadcast_axes = broadcast->get_broadcast_axes();
reference::broadcast<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
in_shape,
out_shape,
broadcast_axes);
}
else if (node_op == "Ceiling")
{
reference::ceiling<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Concat")
{
const op::Concat* concat = static_cast<const op::Concat*>(&node);
std::vector<const T*> in_args;
std::vector<Shape> in_shapes;
for (std::shared_ptr<HostTensorView> arg : args)
{
in_args.push_back(reinterpret_cast<T*>(arg->get_data_ptr()));
in_shapes.push_back(arg->get_shape());
}
reference::concat<T>(in_args,
reinterpret_cast<T*>(out[0]->get_data_ptr()),
in_shapes,
out[0]->get_shape(),
concat->get_concatenation_axis());
}
else if (node_op == "Constant")
{
const op::Constant* c = static_cast<const op::Constant*>(&node);
reference::constant<T>(reinterpret_cast<const T*>(c->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Convolution")
{
auto c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
c->get_window_movement_strides(),
c->get_window_dilation_strides(),
c->get_padding_below(),
c->get_padding_above(),
c->get_data_dilation_strides(),
0,
1,
1,
0,
0,
1,
false);
}
else if (node_op == "ConvolutionBackpropFilters")
{
auto c = static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
1,
0,
0,
1,
1,
0,
false);
}
else if (node_op == "ConvolutionBackpropData")
{
// Note that args[1] and args[0] are switched here from the usual order.
auto c = static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[1]->get_shape(),
args[0]->get_shape(),
out[0]->get_shape(),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
0,
1,
0,
1,
0,
1,
true);
}
else if (node_op == "Cos")
{
reference::cos<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Cosh")
{
reference::cosh<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Divide")
{
reference::divide<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Dot")
{
op::Dot* dot = dynamic_cast<op::Dot*>(&node);
reference::dot(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
dot->get_reduction_axes_count());
}
else if (node_op == "Equal") reference::avg_pool<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
{ reinterpret_cast<T*>(out[0]->get_data_ptr()),
reference::equal<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), args[0]->get_shape(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), out[0]->get_shape(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), avg_pool->get_window_shape(),
out[0]->get_element_count()); avg_pool->get_window_movement_strides(),
} avg_pool->get_padding_below(),
else if (node_op == "Exp") avg_pool->get_padding_above(),
{ avg_pool->get_include_padding_in_avg_computation());
reference::exp<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), }
reinterpret_cast<T*>(out[0]->get_data_ptr()), else if (node_op == "AvgPoolBackprop")
out[0]->get_element_count()); {
} op::AvgPoolBackprop* apb = dynamic_cast<op::AvgPoolBackprop*>(&node);
else if (node_op == "Floor") reference::avg_pool_backprop<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
{
reference::floor<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); args[0]->get_shape(),
} out[0]->get_shape(),
else if (node_op == "FunctionCall") apb->get_window_shape(),
{ apb->get_window_movement_strides(),
std::shared_ptr<Function> function = node.get_functions()[0]; apb->get_padding_below(),
apb->get_padding_above(),
apb->get_include_padding_in_avg_computation());
}
else if (node_op == "Broadcast")
{
op::Broadcast* broadcast = dynamic_cast<op::Broadcast*>(&node);
Shape in_shape = args[0]->get_shape();
Shape out_shape = out[0]->get_shape();
AxisSet broadcast_axes = broadcast->get_broadcast_axes();
reference::broadcast<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
in_shape,
out_shape,
broadcast_axes);
}
else if (node_op == "Ceiling")
{
reference::ceiling<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Concat")
{
const op::Concat* concat = static_cast<const op::Concat*>(&node);
std::vector<const T*> in_args;
std::vector<Shape> in_shapes;
for (std::shared_ptr<HostTensorView> arg : args)
{
in_args.push_back(reinterpret_cast<T*>(arg->get_data_ptr()));
in_shapes.push_back(arg->get_shape());
}
reference::concat<T>(in_args,
reinterpret_cast<T*>(out[0]->get_data_ptr()),
in_shapes,
out[0]->get_shape(),
concat->get_concatenation_axis());
}
else if (node_op == "Constant")
{
const op::Constant* c = static_cast<const op::Constant*>(&node);
reference::constant<T>(reinterpret_cast<const T*>(c->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Convert")
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
element::Type type = node.get_element_type();
if (type == element::boolean)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::f32)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<float*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::f64)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<double*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::i8)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<int8_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::i16)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<int16_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::i32)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<int32_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::i64)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<int64_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::u8)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<uint8_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::u16)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<uint16_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::u32)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<uint32_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (type == element::u64)
{
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<uint64_t*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else
{
std::stringstream ss;
ss << "unsupported element type " << type << " op Convert";
throw std::runtime_error(ss.str());
}
}
else if (node_op == "Convolution")
{
auto c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
c->get_window_movement_strides(),
c->get_window_dilation_strides(),
c->get_padding_below(),
c->get_padding_above(),
c->get_data_dilation_strides(),
0,
1,
1,
0,
0,
1,
false);
}
else if (node_op == "ConvolutionBackpropFilters")
{
auto c = static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
1,
0,
0,
1,
1,
0,
false);
}
else if (node_op == "ConvolutionBackpropData")
{
// Note that args[1] and args[0] are switched here from the usual order.
auto c = static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[1]->get_shape(),
args[0]->get_shape(),
out[0]->get_shape(),
c->get_window_movement_strides_backward(),
c->get_window_dilation_strides_backward(),
c->get_padding_below_backward(),
c->get_padding_above_backward(),
c->get_data_dilation_strides_backward(),
0,
1,
0,
1,
0,
1,
true);
}
else if (node_op == "Cos")
{
reference::cos<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Cosh")
{
reference::cosh<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Divide")
{
reference::divide<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Dot")
{
op::Dot* dot = dynamic_cast<op::Dot*>(&node);
std::vector<std::shared_ptr<runtime::TensorView>> outputs; reference::dot(reinterpret_cast<T*>(args[0]->get_data_ptr()),
for (auto tv : out) reinterpret_cast<T*>(args[1]->get_data_ptr()),
{ reinterpret_cast<T*>(out[0]->get_data_ptr()),
outputs.push_back(std::static_pointer_cast<runtime::TensorView>(tv)); args[0]->get_shape(),
} args[1]->get_shape(),
out[0]->get_shape(),
dot->get_reduction_axes_count());
}
std::vector<std::shared_ptr<runtime::TensorView>> inputs; else if (node_op == "Equal")
for (auto tv : args) {
{ reference::equal<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
inputs.push_back(std::static_pointer_cast<runtime::TensorView>(tv)); reinterpret_cast<T*>(args[1]->get_data_ptr()),
} reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Exp")
{
reference::exp<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Floor")
{
reference::floor<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "FunctionCall")
{
std::shared_ptr<Function> function = node.get_functions()[0];
call(function, outputs, inputs); std::vector<std::shared_ptr<runtime::TensorView>> outputs;
} for (auto tv : out)
else if (node_op == "Greater") {
{ outputs.push_back(std::static_pointer_cast<runtime::TensorView>(tv));
reference::greater<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), }
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "GreaterEq")
{
reference::greater_eq<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Less")
{
reference::less<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "LessEq")
{
reference::less_eq<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Log")
{
reference::log<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Max")
{
const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
out[0]->get_shape(),
max->get_reduction_axes());
}
else if (node_op == "Maximum")
{
reference::maximum<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "MaxPool")
{
op::MaxPool* max_pool = dynamic_cast<op::MaxPool*>(&node);
reference::max_pool<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), std::vector<std::shared_ptr<runtime::TensorView>> inputs;
reinterpret_cast<T*>(out[0]->get_data_ptr()), for (auto tv : args)
args[0]->get_shape(), {
out[0]->get_shape(), inputs.push_back(std::static_pointer_cast<runtime::TensorView>(tv));
max_pool->get_window_shape(), }
max_pool->get_window_movement_strides(),
max_pool->get_padding_below(),
max_pool->get_padding_above());
}
else if (node_op == "MaxPoolBackprop")
{
op::MaxPoolBackprop* max_pool_backprop =
dynamic_cast<op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>( call(function, outputs, inputs);
reinterpret_cast<T*>(args[0]->get_data_ptr()), }
reinterpret_cast<T*>(args[1]->get_data_ptr()), else if (node_op == "Greater")
reinterpret_cast<T*>(out[0]->get_data_ptr()), {
args[1]->get_shape(), reference::greater<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
out[0]->get_shape(), reinterpret_cast<T*>(args[1]->get_data_ptr()),
max_pool_backprop->get_window_shape(), reinterpret_cast<char*>(out[0]->get_data_ptr()),
max_pool_backprop->get_window_movement_strides(), out[0]->get_element_count());
max_pool_backprop->get_padding_below(), }
max_pool_backprop->get_padding_above()); else if (node_op == "GreaterEq")
} {
else if (node_op == "Min") reference::greater_eq<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
{ reinterpret_cast<T*>(args[1]->get_data_ptr()),
const op::Min* min = static_cast<const op::Min*>(&node); reinterpret_cast<char*>(out[0]->get_data_ptr()),
reference::min<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), out[0]->get_element_count());
reinterpret_cast<T*>(out[0]->get_data_ptr()), }
args[0]->get_shape(), else if (node_op == "Less")
out[0]->get_shape(), {
min->get_reduction_axes()); reference::less<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
} reinterpret_cast<T*>(args[1]->get_data_ptr()),
else if (node_op == "Minimum") reinterpret_cast<char*>(out[0]->get_data_ptr()),
{ out[0]->get_element_count());
reference::minimum<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), }
reinterpret_cast<T*>(args[1]->get_data_ptr()), else if (node_op == "LessEq")
reinterpret_cast<T*>(out[0]->get_data_ptr()), {
out[0]->get_element_count()); reference::less_eq<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
} reinterpret_cast<T*>(args[1]->get_data_ptr()),
else if (node_op == "Multiply") reinterpret_cast<char*>(out[0]->get_data_ptr()),
{ out[0]->get_element_count());
reference::multiply<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), }
reinterpret_cast<T*>(args[1]->get_data_ptr()), else if (node_op == "Log")
reinterpret_cast<T*>(out[0]->get_data_ptr()), {
out[0]->get_element_count()); reference::log<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
} reinterpret_cast<T*>(out[0]->get_data_ptr()),
else if (node_op == "Negative") out[0]->get_element_count());
{ }
reference::negate<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), else if (node_op == "Max")
reinterpret_cast<T*>(out[0]->get_data_ptr()), {
out[0]->get_element_count()); const op::Max* max = static_cast<const op::Max*>(&node);
} reference::max<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
else if (node_op == "Not") reinterpret_cast<T*>(out[0]->get_data_ptr()),
{ args[0]->get_shape(),
reference::logical_not(reinterpret_cast<char*>(args[0]->get_data_ptr()), out[0]->get_shape(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), max->get_reduction_axes());
out[0]->get_element_count()); }
} else if (node_op == "Maximum")
else if (node_op == "NotEqual") {
{ reference::maximum<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reference::not_equal<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_element_count());
out[0]->get_element_count()); }
} else if (node_op == "MaxPool")
else if (node_op == "OneHot") {
{ op::MaxPool* max_pool = dynamic_cast<op::MaxPool*>(&node);
auto oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
out[0]->get_shape(),
oh->get_one_hot_axis());
}
else if (node_op == "Or")
{
reference::logical_or(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Parameter")
{
}
else if (node_op == "Pad")
{
op::Pad* pad = dynamic_cast<op::Pad*>(&node);
reference::pad(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::max_pool<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_shape(),
node.get_inputs().at(0).get_shape(), out[0]->get_shape(),
node.get_output_shape(0), max_pool->get_window_shape(),
pad->get_padding_below(), max_pool->get_window_movement_strides(),
pad->get_padding_above(), max_pool->get_padding_below(),
pad->get_padding_interior()); max_pool->get_padding_above());
} }
else if (node_op == "Power") else if (node_op == "MaxPoolBackprop")
{ {
reference::power<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), op::MaxPoolBackprop* max_pool_backprop = dynamic_cast<op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()), reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); args[1]->get_shape(),
} out[0]->get_shape(),
else if (node_op == "Product") max_pool_backprop->get_window_shape(),
{ max_pool_backprop->get_window_movement_strides(),
const op::Product* product = static_cast<const op::Product*>(&node); max_pool_backprop->get_padding_below(),
reference::product<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), max_pool_backprop->get_padding_above());
reinterpret_cast<T*>(out[0]->get_data_ptr()), }
args[0]->get_shape(), else if (node_op == "Min")
out[0]->get_shape(), {
product->get_reduction_axes()); const op::Min* min = static_cast<const op::Min*>(&node);
} reference::min<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
else if (node_op == "Reduce") reinterpret_cast<T*>(out[0]->get_data_ptr()),
{ args[0]->get_shape(),
op::Reduce* reduce = dynamic_cast<op::Reduce*>(&node); out[0]->get_shape(),
std::shared_ptr<Function> reduction_function = reduce->get_functions()[0]; min->get_reduction_axes());
}
else if (node_op == "Minimum")
{
reference::minimum<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Multiply")
{
reference::multiply<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Negative")
{
reference::negate<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Not")
{
reference::logical_not(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "NotEqual")
{
reference::not_equal<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "OneHot")
{
auto oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
out[0]->get_shape(),
oh->get_one_hot_axis());
}
else if (node_op == "Or")
{
reference::logical_or(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Parameter")
{
}
else if (node_op == "Pad")
{
op::Pad* pad = dynamic_cast<op::Pad*>(&node);
reference::pad(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
node.get_inputs().at(0).get_shape(),
node.get_output_shape(0),
pad->get_padding_below(),
pad->get_padding_above(),
pad->get_padding_interior());
}
else if (node_op == "Power")
{
reference::power<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Product")
{
const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
out[0]->get_shape(),
product->get_reduction_axes());
}
else if (node_op == "Reduce")
{
op::Reduce* reduce = dynamic_cast<op::Reduce*>(&node);
std::shared_ptr<Function> reduction_function = reduce->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
T y) -> T { auto tx = std::make_shared<HostTensorView>(
auto tx = std::make_shared<HostTensorView>( node.get_inputs().at(0).get_element_type(), Shape{}, "reduce_temp_x");
node.get_inputs().at(0).get_element_type(), auto ty = std::make_shared<HostTensorView>(
Shape{}, node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_temp_y");
"reduce_temp_x"); auto tr = std::make_shared<HostTensorView>(
auto ty = std::make_shared<HostTensorView>( node.get_output_element_type(0), Shape{}, "reduce_temp_r");
node.get_inputs().at(1).get_element_type(), *(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
Shape{}, *(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
"reduce_temp_y"); call(reduction_function, {tr}, {tx, ty});
auto tr = std::make_shared<HostTensorView>( return *(reinterpret_cast<T*>(tr->get_data_ptr()));
node.get_output_element_type(0), Shape{}, "reduce_temp_r"); };
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
call(reduction_function, {tr}, {tx, ty});
return *(reinterpret_cast<T*>(tr->get_data_ptr()));
};
reference::reduce(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::reduce(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()), reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
node.get_inputs().at(0).get_shape(), node.get_inputs().at(0).get_shape(),
node.get_output_shape(0), node.get_output_shape(0),
reduce->get_reduction_axes(), reduce->get_reduction_axes(),
f); f);
} }
else if (node_op == "ReduceWindow") else if (node_op == "ReduceWindow")
{ {
op::ReduceWindow* reduce_window = dynamic_cast<op::ReduceWindow*>(&node); op::ReduceWindow* reduce_window = dynamic_cast<op::ReduceWindow*>(&node);
std::shared_ptr<Function> reduction_function = std::shared_ptr<Function> reduction_function = reduce_window->get_functions()[0];
reduce_window->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
T y) -> T { auto tx = std::make_shared<HostTensorView>(
auto tx = std::make_shared<HostTensorView>( node.get_inputs().at(0).get_element_type(), Shape{}, "reduce_window_temp_x");
node.get_inputs().at(0).get_element_type(), auto ty = std::make_shared<HostTensorView>(
Shape{}, node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_window_temp_y");
"reduce_window_temp_x"); auto tr = std::make_shared<HostTensorView>(
auto ty = std::make_shared<HostTensorView>( node.get_output_element_type(0), Shape{}, "reduce_window_temp_r");
node.get_inputs().at(1).get_element_type(), *(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
Shape{}, *(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
"reduce_window_temp_y"); call(reduction_function, {tr}, {tx, ty});
auto tr = std::make_shared<HostTensorView>( return *(reinterpret_cast<T*>(tr->get_data_ptr()));
node.get_output_element_type(0), Shape{}, "reduce_window_temp_r"); };
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
call(reduction_function, {tr}, {tx, ty});
return *(reinterpret_cast<T*>(tr->get_data_ptr()));
};
reference::reduce_window(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::reduce_window(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()), reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
node.get_inputs().at(0).get_shape(), node.get_inputs().at(0).get_shape(),
node.get_output_shape(0), node.get_output_shape(0),
f, f,
reduce_window->get_window_shape(), reduce_window->get_window_shape(),
reduce_window->get_window_movement_strides()); reduce_window->get_window_movement_strides());
} }
else if (node_op == "Relu") else if (node_op == "Relu")
{ {
reference::relu<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::relu<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "ReluBackprop") else if (node_op == "ReluBackprop")
{ {
reference::relu_backprop<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::relu_backprop<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()), reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "ReplaceSlice") else if (node_op == "ReplaceSlice")
{ {
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node); const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::replace_slice<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()), reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[1]->get_shape(), args[1]->get_shape(),
slice->get_lower_bounds(), slice->get_lower_bounds(),
slice->get_upper_bounds(), slice->get_upper_bounds(),
slice->get_strides(), slice->get_strides(),
out[0]->get_shape()); out[0]->get_shape());
} }
else if (node_op == "Reshape") else if (node_op == "Reshape")
{ {
op::Reshape* reshape = dynamic_cast<op::Reshape*>(&node); op::Reshape* reshape = dynamic_cast<op::Reshape*>(&node);
reference::reshape(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::reshape(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(), args[0]->get_shape(),
reshape->get_input_order(), reshape->get_input_order(),
out[0]->get_shape()); out[0]->get_shape());
} }
else if (node_op == "Result") else if (node_op == "Result")
{ {
op::Result* res = dynamic_cast<op::Result*>(&node); op::Result* res = dynamic_cast<op::Result*>(&node);
reference::result(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::result(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
shape_size(res->get_shape())); shape_size(res->get_shape()));
} }
else if (node_op == "Reverse") else if (node_op == "Reverse")
{ {
op::Reverse* reverse = dynamic_cast<op::Reverse*>(&node); op::Reverse* reverse = dynamic_cast<op::Reverse*>(&node);
reference::reverse(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::reverse(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
reverse->get_reversed_axes()); reverse->get_reversed_axes());
} }
else if (node_op == "Sign") else if (node_op == "Select")
{ {
reference::sign<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::select<T>(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(args[1]->get_data_ptr()),
out[0]->get_element_count()); reinterpret_cast<T*>(args[2]->get_data_ptr()),
} reinterpret_cast<T*>(out[0]->get_data_ptr()),
else if (node_op == "Sin") out[0]->get_element_count());
{ }
reference::sin<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), else if (node_op == "SelectAndScatter")
reinterpret_cast<T*>(out[0]->get_data_ptr()), {
out[0]->get_element_count()); ngraph::op::SelectAndScatter* select_and_scatter =
} dynamic_cast<ngraph::op::SelectAndScatter*>(&node);
else if (node_op == "Sinh")
{ std::shared_ptr<ngraph::Function> selection_function =
reference::sinh<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), select_and_scatter->get_functions()[0];
reinterpret_cast<T*>(out[0]->get_data_ptr()), std::function<bool(T, T)> f_selection = [this, &node, selection_function](T x,
out[0]->get_element_count()); T y) -> bool {
} auto tx = std::make_shared<runtime::HostTensorView>(
else if (node_op == "Slice") node.get_inputs().at(0).get_element_type(), Shape{}, "selection_temp_x");
{ auto ty = std::make_shared<runtime::HostTensorView>(
const op::Slice* slice = static_cast<const op::Slice*>(&node); node.get_inputs().at(1).get_element_type(), Shape{}, "selection_temp_y");
reference::slice<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), auto tr = std::make_shared<runtime::HostTensorView>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), element::boolean, Shape{}, "selection_temp_r");
args[0]->get_shape(), *(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
slice->get_lower_bounds(), *(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
slice->get_upper_bounds(), call(selection_function, {tr}, {tx, ty});
slice->get_strides(), return *(reinterpret_cast<char*>(tr->get_data_ptr()));
out[0]->get_shape()); };
}
else if (node_op == "Softmax") std::shared_ptr<ngraph::Function> scatter_function =
{ select_and_scatter->get_functions()[1];
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node); std::function<T(T, T)> f_scatter = [this, &node, scatter_function](T x, T y) -> T {
reference::softmax<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), auto tx = std::make_shared<runtime::HostTensorView>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), node.get_inputs().at(0).get_element_type(), Shape{}, "scatter_temp_x");
out[0]->get_shape(), auto ty = std::make_shared<runtime::HostTensorView>(
softmax->get_axes()); node.get_inputs().at(1).get_element_type(), Shape{}, "scatter_temp_y");
} auto tr = std::make_shared<runtime::HostTensorView>(
else if (node_op == "Sqrt") node.get_output_element_type(0), Shape{}, "scatter_temp_r");
{ *(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
reference::sqrt<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), *(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
reinterpret_cast<T*>(out[0]->get_data_ptr()), call(scatter_function, {tr}, {tx, ty});
out[0]->get_element_count()); return *(reinterpret_cast<T*>(tr->get_data_ptr()));
}
else if (node_op == "Subtract")
{
reference::subtract<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Sum")
{
const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
out[0]->get_shape(),
sum->get_reduction_axes());
}
else if (node_op == "Tan")
{
reference::tan<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Tanh")
{
reference::tanh<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else
{
std::stringstream ss;
ss << "unsupported op " << node_op;
throw ngraph_error(ss.str());
}
}
}; };
reference::select_and_scatter<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[2]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
f_selection,
f_scatter,
select_and_scatter->get_window_shape(),
select_and_scatter->get_window_movement_strides());
}
else if (node_op == "Sign")
{
reference::sign<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Sin")
{
reference::sin<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Sinh")
{
reference::sinh<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Slice")
{
const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
slice->get_strides(),
out[0]->get_shape());
}
else if (node_op == "Softmax")
{
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_shape(),
softmax->get_axes());
}
else if (node_op == "Sqrt")
{
reference::sqrt<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Subtract")
{
reference::subtract<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Sum")
{
const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
out[0]->get_shape(),
sum->get_reduction_axes());
}
else if (node_op == "Tan")
{
reference::tan<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Tanh")
{
reference::tanh<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else
{
std::stringstream ss;
ss << "unsupported op " << node_op;
throw ngraph_error(ss.str());
} }
} }
} };
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include <iostream>
namespace ngraph namespace ngraph
{ {
......
...@@ -829,7 +829,6 @@ TEST(${BACKEND_NAME}, backwards_log) ...@@ -829,7 +829,6 @@ TEST(${BACKEND_NAME}, backwards_log)
TEST(${BACKEND_NAME}, backwards_maximum) TEST(${BACKEND_NAME}, backwards_maximum)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}"); // no convert support
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
test::Uniform<float> rng(-1.0f, 1.0f); test::Uniform<float> rng(-1.0f, 1.0f);
...@@ -848,7 +847,6 @@ TEST(${BACKEND_NAME}, backwards_maximum) ...@@ -848,7 +847,6 @@ TEST(${BACKEND_NAME}, backwards_maximum)
TEST(${BACKEND_NAME}, backwards_minimum) TEST(${BACKEND_NAME}, backwards_minimum)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}"); // no convert support
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
test::Uniform<float> rng(-1.0f, 1.0f); test::Uniform<float> rng(-1.0f, 1.0f);
...@@ -1019,7 +1017,6 @@ TEST(${BACKEND_NAME}, backwards_reshape) ...@@ -1019,7 +1017,6 @@ TEST(${BACKEND_NAME}, backwards_reshape)
TEST(${BACKEND_NAME}, backwards_select) TEST(${BACKEND_NAME}, backwards_select)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}"); SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
...@@ -1049,7 +1046,6 @@ TEST(${BACKEND_NAME}, backwards_select) ...@@ -1049,7 +1046,6 @@ TEST(${BACKEND_NAME}, backwards_select)
TEST(${BACKEND_NAME}, backwards_select_nested) TEST(${BACKEND_NAME}, backwards_select_nested)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}"); SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
......
...@@ -1349,7 +1349,6 @@ TEST(${BACKEND_NAME}, notequal) ...@@ -1349,7 +1349,6 @@ TEST(${BACKEND_NAME}, notequal)
TEST(${BACKEND_NAME}, select) TEST(${BACKEND_NAME}, select)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2, 2}; Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape); auto A = make_shared<op::Parameter>(element::boolean, shape);
...@@ -1767,7 +1766,6 @@ TEST(${BACKEND_NAME}, broadcast_matrix_2) ...@@ -1767,7 +1766,6 @@ TEST(${BACKEND_NAME}, broadcast_matrix_2)
TEST(${BACKEND_NAME}, convert_int32_float32) TEST(${BACKEND_NAME}, convert_int32_float32)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
Shape shape{2, 2}; Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::i32, shape); auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = auto f =
...@@ -1786,7 +1784,6 @@ TEST(${BACKEND_NAME}, convert_int32_float32) ...@@ -1786,7 +1784,6 @@ TEST(${BACKEND_NAME}, convert_int32_float32)
TEST(${BACKEND_NAME}, convert_int32_bool) TEST(${BACKEND_NAME}, convert_int32_bool)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
Shape shape{2, 2}; Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::i32, shape); auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::Convert>(A, element::boolean), auto f = make_shared<Function>(make_shared<op::Convert>(A, element::boolean),
...@@ -1805,7 +1802,6 @@ TEST(${BACKEND_NAME}, convert_int32_bool) ...@@ -1805,7 +1802,6 @@ TEST(${BACKEND_NAME}, convert_int32_bool)
TEST(${BACKEND_NAME}, convert_float32_bool) TEST(${BACKEND_NAME}, convert_float32_bool)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
Shape shape{2, 2}; Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape); auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Convert>(A, element::boolean), auto f = make_shared<Function>(make_shared<op::Convert>(A, element::boolean),
...@@ -5148,7 +5144,6 @@ TEST(${BACKEND_NAME}, reduce_window_emulating_max_pool_2d_1channel_1image_stride ...@@ -5148,7 +5144,6 @@ TEST(${BACKEND_NAME}, reduce_window_emulating_max_pool_2d_1channel_1image_stride
// //
TEST(${BACKEND_NAME}, select_and_scatter_with_overlap) TEST(${BACKEND_NAME}, select_and_scatter_with_overlap)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_sel_a{}; Shape shape_sel_a{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a); auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
...@@ -5203,7 +5198,6 @@ TEST(${BACKEND_NAME}, select_and_scatter_with_overlap) ...@@ -5203,7 +5198,6 @@ TEST(${BACKEND_NAME}, select_and_scatter_with_overlap)
// //
TEST(${BACKEND_NAME}, select_and_scatter_without_overlap) TEST(${BACKEND_NAME}, select_and_scatter_without_overlap)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_sel_a{}; Shape shape_sel_a{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a); auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
...@@ -5258,7 +5252,6 @@ TEST(${BACKEND_NAME}, select_and_scatter_without_overlap) ...@@ -5258,7 +5252,6 @@ TEST(${BACKEND_NAME}, select_and_scatter_without_overlap)
// //
TEST(${BACKEND_NAME}, select_and_scatter_3d_without_overlap) TEST(${BACKEND_NAME}, select_and_scatter_3d_without_overlap)
{ {
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_sel_a{}; Shape shape_sel_a{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a); auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment