Unverified Commit 253d4cdf authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Pass HostTensor all the way down (#2454)

* pass HostTensor all the way down

* fix distributed build error
parent 0a3858a0
...@@ -182,22 +182,11 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime:: ...@@ -182,22 +182,11 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
return true; return true;
} }
void runtime::interpreter::INTExecutable::generate_calls( void runtime::interpreter::INTExecutable::generate_calls(const element::Type& type,
const element::Type& type, const NodeWrapper& op,
const NodeWrapper& op, const vector<shared_ptr<HostTensor>>& out,
const vector<shared_ptr<HostTensor>>& outputs, const vector<shared_ptr<HostTensor>>& in)
const vector<shared_ptr<HostTensor>>& inputs)
{ {
vector<void*> out;
vector<const void*> in;
for (auto t : outputs)
{
out.push_back(t->get_data_ptr());
}
for (auto t : inputs)
{
in.push_back(t->get_data_ptr());
}
stringstream ss; stringstream ss;
switch (type.get_type_enum()) switch (type.get_type_enum())
{ {
......
...@@ -181,8 +181,8 @@ private: ...@@ -181,8 +181,8 @@ private:
template <typename T> template <typename T>
void op_engine(const NodeWrapper& node_wrapper, void op_engine(const NodeWrapper& node_wrapper,
const std::vector<void*>& out, const std::vector<std::shared_ptr<HostTensor>>& out,
const std::vector<const void*>& args) const std::vector<std::shared_ptr<HostTensor>>& args)
{ {
const Node& node = node_wrapper.get_node(); const Node& node = node_wrapper.get_node();
std::string node_op = node.description(); std::string node_op = node.description();
...@@ -200,30 +200,30 @@ private: ...@@ -200,30 +200,30 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::abs<T>( reference::abs<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Acos: case OP_TYPEID::Acos:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::acos<T>( reference::acos<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Add: case OP_TYPEID::Add:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::add<T>(static_cast<const T*>(args[0]), reference::add<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::All: case OP_TYPEID::All:
{ {
const op::All* all = static_cast<const op::All*>(&node); const op::All* all = static_cast<const op::All*>(&node);
reference::all(static_cast<const char*>(args[0]), reference::all(args[0]->get_data_ptr<const char>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
all->get_reduction_axes()); all->get_reduction_axes());
...@@ -231,8 +231,8 @@ private: ...@@ -231,8 +231,8 @@ private:
} }
case OP_TYPEID::AllReduce: { case OP_TYPEID::AllReduce: {
#ifdef NGRAPH_DISTRIBUTED_ENABLE #ifdef NGRAPH_DISTRIBUTED_ENABLE
reference::allreduce<T>(static_cast<T*>(const_cast<void*>(args[0])), reference::allreduce<T>(args[0]->get_data_ptr<T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_element_type(0), node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0)))); static_cast<int>(shape_size(node.get_input_shape(0))));
#endif #endif
...@@ -241,17 +241,17 @@ private: ...@@ -241,17 +241,17 @@ private:
case OP_TYPEID::And: case OP_TYPEID::And:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_and(static_cast<const T*>(args[0]), reference::logical_and(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Any: case OP_TYPEID::Any:
{ {
const op::Any* any = static_cast<const op::Any*>(&node); const op::Any* any = static_cast<const op::Any*>(&node);
reference::any(static_cast<const char*>(args[0]), reference::any(args[0]->get_data_ptr<const char>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
any->get_reduction_axes()); any->get_reduction_axes());
...@@ -263,16 +263,16 @@ private: ...@@ -263,16 +263,16 @@ private:
auto element_type = node.get_output_element_type(0); auto element_type = node.get_output_element_type(0);
if (element_type == element::i64) if (element_type == element::i64)
{ {
reference::argmin<T, int64_t>(static_cast<const T*>(args[0]), reference::argmin<T, int64_t>(args[0]->get_data_ptr<const T>(),
static_cast<int64_t*>(out[0]), out[0]->get_data_ptr<int64_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
argmin->get_reduction_axis()); argmin->get_reduction_axis());
} }
else if (element_type == element::i32) else if (element_type == element::i32)
{ {
reference::argmin<T, int32_t>(static_cast<const T*>(args[0]), reference::argmin<T, int32_t>(args[0]->get_data_ptr<const T>(),
static_cast<int32_t*>(out[0]), out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
argmin->get_reduction_axis()); argmin->get_reduction_axis());
...@@ -289,16 +289,16 @@ private: ...@@ -289,16 +289,16 @@ private:
auto element_type = node.get_output_element_type(0); auto element_type = node.get_output_element_type(0);
if (element_type == element::i64) if (element_type == element::i64)
{ {
reference::argmax<T, int64_t>(static_cast<const T*>(args[0]), reference::argmax<T, int64_t>(args[0]->get_data_ptr<const T>(),
static_cast<int64_t*>(out[0]), out[0]->get_data_ptr<int64_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
argmax->get_reduction_axis()); argmax->get_reduction_axis());
} }
else if (element_type == element::i32) else if (element_type == element::i32)
{ {
reference::argmax<T, int32_t>(static_cast<const T*>(args[0]), reference::argmax<T, int32_t>(args[0]->get_data_ptr<const T>(),
static_cast<int32_t*>(out[0]), out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
argmax->get_reduction_axis()); argmax->get_reduction_axis());
...@@ -313,22 +313,22 @@ private: ...@@ -313,22 +313,22 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::asin<T>( reference::asin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Atan: case OP_TYPEID::Atan:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::atan<T>( reference::atan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::AvgPool: case OP_TYPEID::AvgPool:
{ {
const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node); const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
reference::avg_pool<T>(static_cast<const T*>(args[0]), reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
avg_pool->get_window_shape(), avg_pool->get_window_shape(),
...@@ -347,11 +347,10 @@ private: ...@@ -347,11 +347,10 @@ private:
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability())); ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
} }
bool training = static_cast<bool>(static_cast<const T*>(args[0])[0]); bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]);
auto state = m_states.at(&node).get(); auto state = m_states.at(&node).get();
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::generate_mask<T>( reference::generate_mask<T>(out[0]->get_data_ptr<T>(), element_count, state, training);
reinterpret_cast<T*>(out[0]), element_count, state, training);
break; break;
} }
case OP_TYPEID::GetOutputElement: case OP_TYPEID::GetOutputElement:
...@@ -361,7 +360,7 @@ private: ...@@ -361,7 +360,7 @@ private:
size_t n = get_output_element->get_n(); size_t n = get_output_element->get_n();
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
size_t num_bytes = element_count * node.get_output_element_type(0).size(); size_t num_bytes = element_count * node.get_output_element_type(0).size();
std::memcpy(static_cast<T*>(out[0]), args[n], num_bytes); std::memcpy(out[0]->get_data_ptr<T>(), args[n]->get_data_ptr<T>(), num_bytes);
break; break;
} }
case OP_TYPEID::BatchNormTraining: case OP_TYPEID::BatchNormTraining:
...@@ -369,12 +368,12 @@ private: ...@@ -369,12 +368,12 @@ private:
const ngraph::op::BatchNormTraining* bn = const ngraph::op::BatchNormTraining* bn =
static_cast<const ngraph::op::BatchNormTraining*>(&node); static_cast<const ngraph::op::BatchNormTraining*>(&node);
reference::batch_norm_training<T>(bn->get_eps_value(), reference::batch_norm_training<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]), args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
static_cast<T*>(out[1]), out[1]->get_data_ptr<T>(),
static_cast<T*>(out[2]), out[2]->get_data_ptr<T>(),
node.get_input_shape(2)); node.get_input_shape(2));
break; break;
} }
...@@ -383,12 +382,12 @@ private: ...@@ -383,12 +382,12 @@ private:
const ngraph::op::BatchNormInference* bn = const ngraph::op::BatchNormInference* bn =
static_cast<const ngraph::op::BatchNormInference*>(&node); static_cast<const ngraph::op::BatchNormInference*>(&node);
reference::batch_norm_inference<T>(bn->get_eps_value(), reference::batch_norm_inference<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]), args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<const T*>(args[3]), args[3]->get_data_ptr<const T>(),
static_cast<const T*>(args[4]), args[4]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(2)); node.get_input_shape(2));
break; break;
} }
...@@ -397,23 +396,23 @@ private: ...@@ -397,23 +396,23 @@ private:
const ngraph::op::BatchNormTrainingBackprop* bn_bprop = const ngraph::op::BatchNormTrainingBackprop* bn_bprop =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(&node); static_cast<const ngraph::op::BatchNormTrainingBackprop*>(&node);
reference::batch_norm_backprop(bn_bprop->get_eps_value(), reference::batch_norm_backprop(bn_bprop->get_eps_value(),
static_cast<const T*>(args[0]), args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<const T*>(args[3]), args[3]->get_data_ptr<const T>(),
static_cast<const T*>(args[4]), args[4]->get_data_ptr<const T>(),
static_cast<const T*>(args[5]), args[5]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
static_cast<T*>(out[1]), out[1]->get_data_ptr<T>(),
static_cast<T*>(out[2]), out[2]->get_data_ptr<T>(),
node.get_input_shape(2)); node.get_input_shape(2));
break; break;
} }
case OP_TYPEID::AvgPoolBackprop: case OP_TYPEID::AvgPoolBackprop:
{ {
const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node); const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(static_cast<const T*>(args[0]), reference::avg_pool_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
apb->get_window_shape(), apb->get_window_shape(),
...@@ -429,8 +428,8 @@ private: ...@@ -429,8 +428,8 @@ private:
Shape in_shape = node.get_input_shape(0); Shape in_shape = node.get_input_shape(0);
Shape out_shape = node.get_output_shape(0); Shape out_shape = node.get_output_shape(0);
AxisSet broadcast_axes = broadcast->get_broadcast_axes(); AxisSet broadcast_axes = broadcast->get_broadcast_axes();
reference::broadcast<T>(static_cast<const T*>(args[0]), reference::broadcast<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
in_shape, in_shape,
out_shape, out_shape,
broadcast_axes); broadcast_axes);
...@@ -441,7 +440,7 @@ private: ...@@ -441,7 +440,7 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::ceiling<T>( reference::ceiling<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Concat: case OP_TYPEID::Concat:
...@@ -451,11 +450,11 @@ private: ...@@ -451,11 +450,11 @@ private:
std::vector<Shape> in_shapes; std::vector<Shape> in_shapes;
for (size_t i = 0; i < node.get_input_size(); i++) for (size_t i = 0; i < node.get_input_size(); i++)
{ {
in_args.push_back(static_cast<const T*>(args[i])); in_args.push_back(args[i]->get_data_ptr<const T>());
in_shapes.push_back(node.get_input_shape(i)); in_shapes.push_back(node.get_input_shape(i));
} }
reference::concat<T>(in_args, reference::concat<T>(in_args,
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
in_shapes, in_shapes,
node.get_output_shape(0), node.get_output_shape(0),
concat->get_concatenation_axis()); concat->get_concatenation_axis());
...@@ -465,7 +464,7 @@ private: ...@@ -465,7 +464,7 @@ private:
{ {
const op::Constant* c = static_cast<const op::Constant*>(&node); const op::Constant* c = static_cast<const op::Constant*>(&node);
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::constant<T>(c->get_data_ptr<T>(), static_cast<T*>(out[0]), element_count); reference::constant<T>(c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::ScalarConstantLike: break; case OP_TYPEID::ScalarConstantLike: break;
...@@ -479,47 +478,56 @@ private: ...@@ -479,47 +478,56 @@ private:
{ {
case element::Type_t::boolean: case element::Type_t::boolean:
reference::convert<T>( reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<char*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
break; break;
case element::Type_t::f32: case element::Type_t::f32:
reference::convert<T>( reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<float*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
break; break;
case element::Type_t::f64: case element::Type_t::f64:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<double*>(out[0]), element_count); out[0]->get_data_ptr<double>(),
element_count);
break; break;
case element::Type_t::i8: case element::Type_t::i8:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<int8_t*>(out[0]), element_count); out[0]->get_data_ptr<int8_t>(),
element_count);
break; break;
case element::Type_t::i16: case element::Type_t::i16:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<int16_t*>(out[0]), element_count); out[0]->get_data_ptr<int16_t>(),
element_count);
break; break;
case element::Type_t::i32: case element::Type_t::i32:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<int32_t*>(out[0]), element_count); out[0]->get_data_ptr<int32_t>(),
element_count);
break; break;
case element::Type_t::i64: case element::Type_t::i64:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<int64_t*>(out[0]), element_count); out[0]->get_data_ptr<int64_t>(),
element_count);
break; break;
case element::Type_t::u8: case element::Type_t::u8:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<uint8_t*>(out[0]), element_count); out[0]->get_data_ptr<uint8_t>(),
element_count);
break; break;
case element::Type_t::u16: case element::Type_t::u16:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<uint16_t*>(out[0]), element_count); out[0]->get_data_ptr<uint16_t>(),
element_count);
break; break;
case element::Type_t::u32: case element::Type_t::u32:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<uint32_t*>(out[0]), element_count); out[0]->get_data_ptr<uint32_t>(),
element_count);
break; break;
case element::Type_t::u64: case element::Type_t::u64:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<uint64_t*>(out[0]), element_count); out[0]->get_data_ptr<uint64_t>(),
element_count);
break; break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
...@@ -532,9 +540,9 @@ private: ...@@ -532,9 +540,9 @@ private:
case OP_TYPEID::Convolution: case OP_TYPEID::Convolution:
{ {
const op::Convolution* c = static_cast<const op::Convolution*>(&node); const op::Convolution* c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(static_cast<const T*>(args[0]), reference::convolution<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
node.get_output_shape(0), node.get_output_shape(0),
...@@ -556,9 +564,9 @@ private: ...@@ -556,9 +564,9 @@ private:
{ {
const op::ConvolutionBackpropFilters* c = const op::ConvolutionBackpropFilters* c =
static_cast<const op::ConvolutionBackpropFilters*>(&node); static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(static_cast<const T*>(args[0]), reference::convolution<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
node.get_output_shape(0), node.get_output_shape(0),
...@@ -581,9 +589,9 @@ private: ...@@ -581,9 +589,9 @@ private:
// Note that args[1] and args[0] are switched here from the usual order. // Note that args[1] and args[0] are switched here from the usual order.
const op::ConvolutionBackpropData* c = const op::ConvolutionBackpropData* c =
static_cast<const op::ConvolutionBackpropData*>(&node); static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(static_cast<const T*>(args[1]), reference::convolution<T>(args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(1), node.get_input_shape(1),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
...@@ -605,14 +613,14 @@ private: ...@@ -605,14 +613,14 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::cos<T>( reference::cos<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Cosh: case OP_TYPEID::Cosh:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::cosh<T>( reference::cosh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Dequantize: case OP_TYPEID::Dequantize:
...@@ -622,20 +630,20 @@ private: ...@@ -622,20 +630,20 @@ private:
if (type == element::f32) if (type == element::f32)
{ {
reference::dequantize<T>(static_cast<const T*>(args[0]), reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const float*>(args[1]), args[1]->get_data_ptr<const float>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<float*>(out[0]), out[0]->get_data_ptr<float>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
dequantize->get_axes()); dequantize->get_axes());
} }
else if (type == element::f64) else if (type == element::f64)
{ {
reference::dequantize<T>(static_cast<const T*>(args[0]), reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const double*>(args[1]), args[1]->get_data_ptr<const double>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<double*>(out[0]), out[0]->get_data_ptr<double>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
dequantize->get_axes()); dequantize->get_axes());
...@@ -652,9 +660,9 @@ private: ...@@ -652,9 +660,9 @@ private:
case OP_TYPEID::Divide: case OP_TYPEID::Divide:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::divide<T>(static_cast<const T*>(args[0]), reference::divide<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -662,9 +670,9 @@ private: ...@@ -662,9 +670,9 @@ private:
{ {
const op::Dot* dot = static_cast<const op::Dot*>(&node); const op::Dot* dot = static_cast<const op::Dot*>(&node);
reference::dot(static_cast<const T*>(args[0]), reference::dot(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
node.get_output_shape(0), node.get_output_shape(0),
...@@ -679,33 +687,33 @@ private: ...@@ -679,33 +687,33 @@ private:
if (type == element::f32) if (type == element::f32)
{ {
reference::embedding<T, float>(static_cast<const float*>(args[0]), reference::embedding<T, float>(args[0]->get_data_ptr<const float>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count, element_count,
embed->get_shape()); embed->get_shape());
} }
else if (type == element::f64) else if (type == element::f64)
{ {
reference::embedding<T, double>(static_cast<const double*>(args[0]), reference::embedding<T, double>(args[0]->get_data_ptr<const double>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count, element_count,
embed->get_shape()); embed->get_shape());
} }
else if (type == element::i32) else if (type == element::i32)
{ {
reference::embedding<T, int>(static_cast<const int*>(args[0]), reference::embedding<T, int>(args[0]->get_data_ptr<const int>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count, element_count,
embed->get_shape()); embed->get_shape());
} }
else if (type == element::i64) else if (type == element::i64)
{ {
reference::embedding<T, int64_t>(static_cast<const int64_t*>(args[0]), reference::embedding<T, int64_t>(args[0]->get_data_ptr<const int64_t>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count, element_count,
embed->get_shape()); embed->get_shape());
} }
...@@ -719,9 +727,9 @@ private: ...@@ -719,9 +727,9 @@ private:
case OP_TYPEID::Equal: case OP_TYPEID::Equal:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::equal<T>(static_cast<const T*>(args[0]), reference::equal<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
...@@ -729,49 +737,49 @@ private: ...@@ -729,49 +737,49 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::exp<T>( reference::exp<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Floor: case OP_TYPEID::Floor:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::floor<T>( reference::floor<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Greater: case OP_TYPEID::Greater:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::greater<T>(static_cast<const T*>(args[0]), reference::greater<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::GreaterEq: case OP_TYPEID::GreaterEq:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::greater_eq<T>(static_cast<const T*>(args[0]), reference::greater_eq<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Less: case OP_TYPEID::Less:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::less<T>(static_cast<const T*>(args[0]), reference::less<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::LessEq: case OP_TYPEID::LessEq:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::less_eq<T>(static_cast<const T*>(args[0]), reference::less_eq<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
...@@ -779,14 +787,14 @@ private: ...@@ -779,14 +787,14 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::log<T>( reference::log<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::LRN: case OP_TYPEID::LRN:
{ {
const op::LRN* lrn = static_cast<const op::LRN*>(&node); const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(static_cast<const T*>(args[0]), reference::lrn<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
lrn->get_alpha(), lrn->get_alpha(),
lrn->get_beta(), lrn->get_beta(),
...@@ -797,8 +805,8 @@ private: ...@@ -797,8 +805,8 @@ private:
case OP_TYPEID::Max: case OP_TYPEID::Max:
{ {
const op::Max* max = static_cast<const op::Max*>(&node); const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(static_cast<const T*>(args[0]), reference::max<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
max->get_reduction_axes()); max->get_reduction_axes());
...@@ -807,9 +815,9 @@ private: ...@@ -807,9 +815,9 @@ private:
case OP_TYPEID::Maximum: case OP_TYPEID::Maximum:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::maximum<T>(static_cast<const T*>(args[0]), reference::maximum<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -817,8 +825,8 @@ private: ...@@ -817,8 +825,8 @@ private:
{ {
const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node); const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node);
reference::max_pool<T>(static_cast<const T*>(args[0]), reference::max_pool<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
max_pool->get_window_shape(), max_pool->get_window_shape(),
...@@ -832,9 +840,9 @@ private: ...@@ -832,9 +840,9 @@ private:
const op::MaxPoolBackprop* max_pool_backprop = const op::MaxPoolBackprop* max_pool_backprop =
static_cast<const op::MaxPoolBackprop*>(&node); static_cast<const op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(static_cast<const T*>(args[0]), reference::max_pool_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(1), node.get_input_shape(1),
node.get_output_shape(0), node.get_output_shape(0),
max_pool_backprop->get_window_shape(), max_pool_backprop->get_window_shape(),
...@@ -846,8 +854,8 @@ private: ...@@ -846,8 +854,8 @@ private:
case OP_TYPEID::Min: case OP_TYPEID::Min:
{ {
const op::Min* min = static_cast<const op::Min*>(&node); const op::Min* min = static_cast<const op::Min*>(&node);
reference::min<T>(static_cast<const T*>(args[0]), reference::min<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
min->get_reduction_axes()); min->get_reduction_axes());
...@@ -856,18 +864,18 @@ private: ...@@ -856,18 +864,18 @@ private:
case OP_TYPEID::Minimum: case OP_TYPEID::Minimum:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::minimum<T>(static_cast<const T*>(args[0]), reference::minimum<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Multiply: case OP_TYPEID::Multiply:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::multiply<T>(static_cast<const T*>(args[0]), reference::multiply<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -875,30 +883,30 @@ private: ...@@ -875,30 +883,30 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::negate<T>( reference::negate<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Not: case OP_TYPEID::Not:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_not( reference::logical_not(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::NotEqual: case OP_TYPEID::NotEqual:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::not_equal<T>(static_cast<const T*>(args[0]), reference::not_equal<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::OneHot: case OP_TYPEID::OneHot:
{ {
const op::OneHot* oh = static_cast<const op::OneHot*>(&node); const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(static_cast<const T*>(args[0]), reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
oh->get_one_hot_axis()); oh->get_one_hot_axis());
...@@ -907,9 +915,9 @@ private: ...@@ -907,9 +915,9 @@ private:
case OP_TYPEID::Or: case OP_TYPEID::Or:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_or(static_cast<const T*>(args[0]), reference::logical_or(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -918,9 +926,9 @@ private: ...@@ -918,9 +926,9 @@ private:
{ {
const op::Pad* pad = static_cast<const op::Pad*>(&node); const op::Pad* pad = static_cast<const op::Pad*>(&node);
reference::pad(static_cast<const T*>(args[0]), reference::pad(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_inputs().at(0).get_shape(), node.get_inputs().at(0).get_shape(),
node.get_output_shape(0), node.get_output_shape(0),
pad->get_padding_below(), pad->get_padding_below(),
...@@ -931,17 +939,17 @@ private: ...@@ -931,17 +939,17 @@ private:
case OP_TYPEID::Power: case OP_TYPEID::Power:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::power<T>(static_cast<const T*>(args[0]), reference::power<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Product: case OP_TYPEID::Product:
{ {
const op::Product* product = static_cast<const op::Product*>(&node); const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(static_cast<const T*>(args[0]), reference::product<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
product->get_reduction_axes()); product->get_reduction_axes());
...@@ -954,10 +962,10 @@ private: ...@@ -954,10 +962,10 @@ private:
if (type == element::u8) if (type == element::u8)
{ {
reference::quantize<T>(static_cast<const T*>(args[0]), reference::quantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const uint8_t*>(args[2]), args[2]->get_data_ptr<const uint8_t>(),
static_cast<uint8_t*>(out[0]), out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
quantize->get_axes(), quantize->get_axes(),
...@@ -965,10 +973,10 @@ private: ...@@ -965,10 +973,10 @@ private:
} }
else if (type == element::i8) else if (type == element::i8)
{ {
reference::quantize<T>(static_cast<const T*>(args[0]), reference::quantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const int8_t*>(args[2]), args[2]->get_data_ptr<const int8_t>(),
static_cast<int8_t*>(out[0]), out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
quantize->get_axes(), quantize->get_axes(),
...@@ -976,10 +984,10 @@ private: ...@@ -976,10 +984,10 @@ private:
} }
else if (type == element::i32) else if (type == element::i32)
{ {
reference::quantize<T>(static_cast<const T*>(args[0]), reference::quantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const int32_t*>(args[2]), args[2]->get_data_ptr<const int32_t>(),
static_cast<int32_t*>(out[0]), out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
quantize->get_axes(), quantize->get_axes(),
...@@ -1009,24 +1017,24 @@ private: ...@@ -1009,24 +1017,24 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::relu<T>( reference::relu<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::ReluBackprop: case OP_TYPEID::ReluBackprop:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::relu_backprop<T>(static_cast<const T*>(args[0]), reference::relu_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::ReplaceSlice: case OP_TYPEID::ReplaceSlice:
{ {
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node); const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(static_cast<const T*>(args[0]), reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(1), node.get_input_shape(1),
slice->get_lower_bounds(), slice->get_lower_bounds(),
slice->get_upper_bounds(), slice->get_upper_bounds(),
...@@ -1037,8 +1045,8 @@ private: ...@@ -1037,8 +1045,8 @@ private:
case OP_TYPEID::Reshape: case OP_TYPEID::Reshape:
{ {
const op::Reshape* reshape = static_cast<const op::Reshape*>(&node); const op::Reshape* reshape = static_cast<const op::Reshape*>(&node);
reference::reshape(static_cast<const T*>(args[0]), reference::reshape(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
reshape->get_input_order(), reshape->get_input_order(),
node.get_output_shape(0)); node.get_output_shape(0));
...@@ -1047,16 +1055,16 @@ private: ...@@ -1047,16 +1055,16 @@ private:
case OP_TYPEID::Result: case OP_TYPEID::Result:
{ {
const op::Result* res = static_cast<const op::Result*>(&node); const op::Result* res = static_cast<const op::Result*>(&node);
reference::result(static_cast<const T*>(args[0]), reference::result(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
shape_size(res->get_shape())); shape_size(res->get_shape()));
break; break;
} }
case OP_TYPEID::Reverse: case OP_TYPEID::Reverse:
{ {
const op::Reverse* reverse = static_cast<const op::Reverse*>(&node); const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
reference::reverse(static_cast<const T*>(args[0]), reference::reverse(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
reverse->get_reversed_axes()); reverse->get_reversed_axes());
...@@ -1068,12 +1076,12 @@ private: ...@@ -1068,12 +1076,12 @@ private:
if (node.get_input_element_type(1) == element::i32) if (node.get_input_element_type(1) == element::i32)
{ {
reference::reverse_sequence<T, int32_t>(static_cast<const T*>(args[0]), reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
reverse->get_batch_axis(), reverse->get_batch_axis(),
reverse->get_sequence_axis(), reverse->get_sequence_axis(),
static_cast<const int32_t*>(args[1])); args[1]->get_data_ptr<const int32_t>());
} }
else else
{ {
...@@ -1084,31 +1092,31 @@ private: ...@@ -1084,31 +1092,31 @@ private:
case OP_TYPEID::Select: case OP_TYPEID::Select:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::select<T>(static_cast<const char*>(args[0]), reference::select<T>(args[0]->get_data_ptr<const char>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::ShapeOf: case OP_TYPEID::ShapeOf:
{ {
reference::shape_of(node.get_input_shape(0), static_cast<uint64_t*>(out[0])); reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>());
break; break;
} }
case OP_TYPEID::Sigmoid: case OP_TYPEID::Sigmoid:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid<T>( reference::sigmoid<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::SigmoidBackprop: case OP_TYPEID::SigmoidBackprop:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid_backprop<T>(static_cast<const T*>(args[0]), reference::sigmoid_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -1116,28 +1124,28 @@ private: ...@@ -1116,28 +1124,28 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sign<T>( reference::sign<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Sin: case OP_TYPEID::Sin:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sin<T>( reference::sin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Sinh: case OP_TYPEID::Sinh:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sinh<T>( reference::sinh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Slice: case OP_TYPEID::Slice:
{ {
const op::Slice* slice = static_cast<const op::Slice*>(&node); const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(static_cast<const T*>(args[0]), reference::slice<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
slice->get_lower_bounds(), slice->get_lower_bounds(),
slice->get_upper_bounds(), slice->get_upper_bounds(),
...@@ -1148,8 +1156,8 @@ private: ...@@ -1148,8 +1156,8 @@ private:
case OP_TYPEID::Softmax: case OP_TYPEID::Softmax:
{ {
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node); const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(static_cast<const T*>(args[0]), reference::softmax<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_output_shape(0), node.get_output_shape(0),
softmax->get_axes()); softmax->get_axes());
break; break;
...@@ -1158,7 +1166,7 @@ private: ...@@ -1158,7 +1166,7 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sqrt<T>( reference::sqrt<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::StopGradient: { throw unsupported_op("Unsupported op 'StopGradient'"); case OP_TYPEID::StopGradient: { throw unsupported_op("Unsupported op 'StopGradient'");
...@@ -1166,17 +1174,17 @@ private: ...@@ -1166,17 +1174,17 @@ private:
case OP_TYPEID::Subtract: case OP_TYPEID::Subtract:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::subtract<T>(static_cast<const T*>(args[0]), reference::subtract<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Sum: case OP_TYPEID::Sum:
{ {
const op::Sum* sum = static_cast<const op::Sum*>(&node); const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(static_cast<const T*>(args[0]), reference::sum<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
sum->get_reduction_axes()); sum->get_reduction_axes());
...@@ -1186,14 +1194,14 @@ private: ...@@ -1186,14 +1194,14 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::tan<T>( reference::tan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Tanh: case OP_TYPEID::Tanh:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::tanh<T>( reference::tanh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::TopK: case OP_TYPEID::TopK:
...@@ -1201,9 +1209,9 @@ private: ...@@ -1201,9 +1209,9 @@ private:
const op::TopK* topk = static_cast<const op::TopK*>(&node); const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (node.get_output_element_type(0) == element::i64) if (node.get_output_element_type(0) == element::i64)
{ {
reference::topk<T, int64_t>(static_cast<const T*>(args[0]), reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
static_cast<int64_t*>(out[0]), out[0]->get_data_ptr<int64_t>(),
static_cast<T*>(out[1]), out[1]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
topk->get_top_k_axis(), topk->get_top_k_axis(),
...@@ -1212,9 +1220,9 @@ private: ...@@ -1212,9 +1220,9 @@ private:
} }
else if (node.get_output_element_type(0) == element::i32) else if (node.get_output_element_type(0) == element::i32)
{ {
reference::topk<T, int32_t>(static_cast<const T*>(args[0]), reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
static_cast<int32_t*>(out[0]), out[0]->get_data_ptr<int32_t>(),
static_cast<T*>(out[1]), out[1]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
topk->get_top_k_axis(), topk->get_top_k_axis(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment