Unverified Commit 253d4cdf authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Pass HostTensor all the way down (#2454)

* pass HostTensor all the way down

* fix distributed build error
parent 0a3858a0
......@@ -182,22 +182,11 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
return true;
}
void runtime::interpreter::INTExecutable::generate_calls(
const element::Type& type,
const NodeWrapper& op,
const vector<shared_ptr<HostTensor>>& outputs,
const vector<shared_ptr<HostTensor>>& inputs)
void runtime::interpreter::INTExecutable::generate_calls(const element::Type& type,
const NodeWrapper& op,
const vector<shared_ptr<HostTensor>>& out,
const vector<shared_ptr<HostTensor>>& in)
{
vector<void*> out;
vector<const void*> in;
for (auto t : outputs)
{
out.push_back(t->get_data_ptr());
}
for (auto t : inputs)
{
in.push_back(t->get_data_ptr());
}
stringstream ss;
switch (type.get_type_enum())
{
......
......@@ -181,8 +181,8 @@ private:
template <typename T>
void op_engine(const NodeWrapper& node_wrapper,
const std::vector<void*>& out,
const std::vector<const void*>& args)
const std::vector<std::shared_ptr<HostTensor>>& out,
const std::vector<std::shared_ptr<HostTensor>>& args)
{
const Node& node = node_wrapper.get_node();
std::string node_op = node.description();
......@@ -200,30 +200,30 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::abs<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Acos:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::acos<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Add:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::add<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::add<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::All:
{
const op::All* all = static_cast<const op::All*>(&node);
reference::all(static_cast<const char*>(args[0]),
static_cast<char*>(out[0]),
reference::all(args[0]->get_data_ptr<const char>(),
out[0]->get_data_ptr<char>(),
node.get_input_shape(0),
node.get_output_shape(0),
all->get_reduction_axes());
......@@ -231,8 +231,8 @@ private:
}
case OP_TYPEID::AllReduce: {
#ifdef NGRAPH_DISTRIBUTED_ENABLE
reference::allreduce<T>(static_cast<T*>(const_cast<void*>(args[0])),
static_cast<T*>(out[0]),
reference::allreduce<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0))));
#endif
......@@ -241,17 +241,17 @@ private:
case OP_TYPEID::And:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_and(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::logical_and(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Any:
{
const op::Any* any = static_cast<const op::Any*>(&node);
reference::any(static_cast<const char*>(args[0]),
static_cast<char*>(out[0]),
reference::any(args[0]->get_data_ptr<const char>(),
out[0]->get_data_ptr<char>(),
node.get_input_shape(0),
node.get_output_shape(0),
any->get_reduction_axes());
......@@ -263,16 +263,16 @@ private:
auto element_type = node.get_output_element_type(0);
if (element_type == element::i64)
{
reference::argmin<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
reference::argmin<T, int64_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int64_t>(),
node.get_input_shape(0),
node.get_output_shape(0),
argmin->get_reduction_axis());
}
else if (element_type == element::i32)
{
reference::argmin<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
reference::argmin<T, int32_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_output_shape(0),
argmin->get_reduction_axis());
......@@ -289,16 +289,16 @@ private:
auto element_type = node.get_output_element_type(0);
if (element_type == element::i64)
{
reference::argmax<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
reference::argmax<T, int64_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int64_t>(),
node.get_input_shape(0),
node.get_output_shape(0),
argmax->get_reduction_axis());
}
else if (element_type == element::i32)
{
reference::argmax<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
reference::argmax<T, int32_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_output_shape(0),
argmax->get_reduction_axis());
......@@ -313,22 +313,22 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::asin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Atan:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::atan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::AvgPool:
{
const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
reference::avg_pool<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
avg_pool->get_window_shape(),
......@@ -347,11 +347,10 @@ private:
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability()));
}
bool training = static_cast<bool>(static_cast<const T*>(args[0])[0]);
bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]);
auto state = m_states.at(&node).get();
size_t element_count = shape_size(node.get_output_shape(0));
reference::generate_mask<T>(
reinterpret_cast<T*>(out[0]), element_count, state, training);
reference::generate_mask<T>(out[0]->get_data_ptr<T>(), element_count, state, training);
break;
}
case OP_TYPEID::GetOutputElement:
......@@ -361,7 +360,7 @@ private:
size_t n = get_output_element->get_n();
size_t element_count = shape_size(node.get_output_shape(0));
size_t num_bytes = element_count * node.get_output_element_type(0).size();
std::memcpy(static_cast<T*>(out[0]), args[n], num_bytes);
std::memcpy(out[0]->get_data_ptr<T>(), args[n]->get_data_ptr<T>(), num_bytes);
break;
}
case OP_TYPEID::BatchNormTraining:
......@@ -369,12 +368,12 @@ private:
const ngraph::op::BatchNormTraining* bn =
static_cast<const ngraph::op::BatchNormTraining*>(&node);
reference::batch_norm_training<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<T*>(out[0]),
static_cast<T*>(out[1]),
static_cast<T*>(out[2]),
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
out[1]->get_data_ptr<T>(),
out[2]->get_data_ptr<T>(),
node.get_input_shape(2));
break;
}
......@@ -383,12 +382,12 @@ private:
const ngraph::op::BatchNormInference* bn =
static_cast<const ngraph::op::BatchNormInference*>(&node);
reference::batch_norm_inference<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<const T*>(args[3]),
static_cast<const T*>(args[4]),
static_cast<T*>(out[0]),
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const T>(),
args[4]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(2));
break;
}
......@@ -397,23 +396,23 @@ private:
const ngraph::op::BatchNormTrainingBackprop* bn_bprop =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(&node);
reference::batch_norm_backprop(bn_bprop->get_eps_value(),
static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<const T*>(args[3]),
static_cast<const T*>(args[4]),
static_cast<const T*>(args[5]),
static_cast<T*>(out[0]),
static_cast<T*>(out[1]),
static_cast<T*>(out[2]),
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const T>(),
args[4]->get_data_ptr<const T>(),
args[5]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
out[1]->get_data_ptr<T>(),
out[2]->get_data_ptr<T>(),
node.get_input_shape(2));
break;
}
case OP_TYPEID::AvgPoolBackprop:
{
const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::avg_pool_backprop<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
apb->get_window_shape(),
......@@ -429,8 +428,8 @@ private:
Shape in_shape = node.get_input_shape(0);
Shape out_shape = node.get_output_shape(0);
AxisSet broadcast_axes = broadcast->get_broadcast_axes();
reference::broadcast<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::broadcast<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
in_shape,
out_shape,
broadcast_axes);
......@@ -441,7 +440,7 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::ceiling<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Concat:
......@@ -451,11 +450,11 @@ private:
std::vector<Shape> in_shapes;
for (size_t i = 0; i < node.get_input_size(); i++)
{
in_args.push_back(static_cast<const T*>(args[i]));
in_args.push_back(args[i]->get_data_ptr<const T>());
in_shapes.push_back(node.get_input_shape(i));
}
reference::concat<T>(in_args,
static_cast<T*>(out[0]),
out[0]->get_data_ptr<T>(),
in_shapes,
node.get_output_shape(0),
concat->get_concatenation_axis());
......@@ -465,7 +464,7 @@ private:
{
const op::Constant* c = static_cast<const op::Constant*>(&node);
size_t element_count = shape_size(node.get_output_shape(0));
reference::constant<T>(c->get_data_ptr<T>(), static_cast<T*>(out[0]), element_count);
reference::constant<T>(c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::ScalarConstantLike: break;
......@@ -479,47 +478,56 @@ private:
{
case element::Type_t::boolean:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<char*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
break;
case element::Type_t::f32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<float*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
break;
case element::Type_t::f64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<double*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<double>(),
element_count);
break;
case element::Type_t::i8:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int8_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int8_t>(),
element_count);
break;
case element::Type_t::i16:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int16_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int16_t>(),
element_count);
break;
case element::Type_t::i32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int32_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int32_t>(),
element_count);
break;
case element::Type_t::i64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int64_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int64_t>(),
element_count);
break;
case element::Type_t::u8:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint8_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint8_t>(),
element_count);
break;
case element::Type_t::u16:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint16_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint16_t>(),
element_count);
break;
case element::Type_t::u32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint32_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint32_t>(),
element_count);
break;
case element::Type_t::u64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint64_t*>(out[0]), element_count);
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint64_t>(),
element_count);
break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
......@@ -532,9 +540,9 @@ private:
case OP_TYPEID::Convolution:
{
const op::Convolution* c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::convolution<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
......@@ -556,9 +564,9 @@ private:
{
const op::ConvolutionBackpropFilters* c =
static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::convolution<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
......@@ -581,9 +589,9 @@ private:
// Note that args[1] and args[0] are switched here from the usual order.
const op::ConvolutionBackpropData* c =
static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(static_cast<const T*>(args[1]),
static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::convolution<T>(args[1]->get_data_ptr<const T>(),
args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(1),
node.get_input_shape(0),
node.get_output_shape(0),
......@@ -605,14 +613,14 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::cos<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Cosh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::cosh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Dequantize:
......@@ -622,20 +630,20 @@ private:
if (type == element::f32)
{
reference::dequantize<T>(static_cast<const T*>(args[0]),
static_cast<const float*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<float*>(out[0]),
reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const float>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<float>(),
node.get_input_shape(0),
node.get_input_shape(1),
dequantize->get_axes());
}
else if (type == element::f64)
{
reference::dequantize<T>(static_cast<const T*>(args[0]),
static_cast<const double*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<double*>(out[0]),
reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const double>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<double>(),
node.get_input_shape(0),
node.get_input_shape(1),
dequantize->get_axes());
......@@ -652,9 +660,9 @@ private:
case OP_TYPEID::Divide:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::divide<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::divide<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
......@@ -662,9 +670,9 @@ private:
{
const op::Dot* dot = static_cast<const op::Dot*>(&node);
reference::dot(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::dot(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
......@@ -679,33 +687,33 @@ private:
if (type == element::f32)
{
reference::embedding<T, float>(static_cast<const float*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::embedding<T, float>(args[0]->get_data_ptr<const float>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count,
embed->get_shape());
}
else if (type == element::f64)
{
reference::embedding<T, double>(static_cast<const double*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::embedding<T, double>(args[0]->get_data_ptr<const double>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count,
embed->get_shape());
}
else if (type == element::i32)
{
reference::embedding<T, int>(static_cast<const int*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::embedding<T, int>(args[0]->get_data_ptr<const int>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count,
embed->get_shape());
}
else if (type == element::i64)
{
reference::embedding<T, int64_t>(static_cast<const int64_t*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::embedding<T, int64_t>(args[0]->get_data_ptr<const int64_t>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count,
embed->get_shape());
}
......@@ -719,9 +727,9 @@ private:
case OP_TYPEID::Equal:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::equal<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::equal<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
......@@ -729,49 +737,49 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::exp<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Floor:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::floor<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Greater:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::greater<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::greater<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
case OP_TYPEID::GreaterEq:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::greater_eq<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::greater_eq<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
case OP_TYPEID::Less:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::less<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::less<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
case OP_TYPEID::LessEq:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::less_eq<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::less_eq<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
......@@ -779,14 +787,14 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::log<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::LRN:
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::lrn<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
lrn->get_alpha(),
lrn->get_beta(),
......@@ -797,8 +805,8 @@ private:
case OP_TYPEID::Max:
{
const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::max<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
max->get_reduction_axes());
......@@ -807,9 +815,9 @@ private:
case OP_TYPEID::Maximum:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::maximum<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::maximum<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
......@@ -817,8 +825,8 @@ private:
{
const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node);
reference::max_pool<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::max_pool<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
max_pool->get_window_shape(),
......@@ -832,9 +840,9 @@ private:
const op::MaxPoolBackprop* max_pool_backprop =
static_cast<const op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::max_pool_backprop<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(1),
node.get_output_shape(0),
max_pool_backprop->get_window_shape(),
......@@ -846,8 +854,8 @@ private:
case OP_TYPEID::Min:
{
const op::Min* min = static_cast<const op::Min*>(&node);
reference::min<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::min<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
min->get_reduction_axes());
......@@ -856,18 +864,18 @@ private:
case OP_TYPEID::Minimum:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::minimum<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::minimum<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Multiply:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::multiply<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::multiply<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
......@@ -875,30 +883,30 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::negate<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Not:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_not(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::NotEqual:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::not_equal<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<char*>(out[0]),
reference::not_equal<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
element_count);
break;
}
case OP_TYPEID::OneHot:
{
const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
oh->get_one_hot_axis());
......@@ -907,9 +915,9 @@ private:
case OP_TYPEID::Or:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_or(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::logical_or(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
......@@ -918,9 +926,9 @@ private:
{
const op::Pad* pad = static_cast<const op::Pad*>(&node);
reference::pad(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::pad(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_inputs().at(0).get_shape(),
node.get_output_shape(0),
pad->get_padding_below(),
......@@ -931,17 +939,17 @@ private:
case OP_TYPEID::Power:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::power<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::power<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Product:
{
const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::product<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
product->get_reduction_axes());
......@@ -954,10 +962,10 @@ private:
if (type == element::u8)
{
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const uint8_t*>(args[2]),
static_cast<uint8_t*>(out[0]),
reference::quantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
......@@ -965,10 +973,10 @@ private:
}
else if (type == element::i8)
{
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const int8_t*>(args[2]),
static_cast<int8_t*>(out[0]),
reference::quantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
......@@ -976,10 +984,10 @@ private:
}
else if (type == element::i32)
{
reference::quantize<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const int32_t*>(args[2]),
static_cast<int32_t*>(out[0]),
reference::quantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const int32_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
......@@ -1009,24 +1017,24 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::relu<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::ReluBackprop:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::relu_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::relu_backprop<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::ReplaceSlice:
{
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(1),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
......@@ -1037,8 +1045,8 @@ private:
case OP_TYPEID::Reshape:
{
const op::Reshape* reshape = static_cast<const op::Reshape*>(&node);
reference::reshape(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::reshape(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
reshape->get_input_order(),
node.get_output_shape(0));
......@@ -1047,16 +1055,16 @@ private:
case OP_TYPEID::Result:
{
const op::Result* res = static_cast<const op::Result*>(&node);
reference::result(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::result(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
shape_size(res->get_shape()));
break;
}
case OP_TYPEID::Reverse:
{
const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
reference::reverse(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::reverse(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
reverse->get_reversed_axes());
......@@ -1068,12 +1076,12 @@ private:
if (node.get_input_element_type(1) == element::i32)
{
reference::reverse_sequence<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
reverse->get_batch_axis(),
reverse->get_sequence_axis(),
static_cast<const int32_t*>(args[1]));
args[1]->get_data_ptr<const int32_t>());
}
else
{
......@@ -1084,31 +1092,31 @@ private:
case OP_TYPEID::Select:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::select<T>(static_cast<const char*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<const T*>(args[2]),
static_cast<T*>(out[0]),
reference::select<T>(args[0]->get_data_ptr<const char>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::ShapeOf:
{
reference::shape_of(node.get_input_shape(0), static_cast<uint64_t*>(out[0]));
reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>());
break;
}
case OP_TYPEID::Sigmoid:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::SigmoidBackprop:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid_backprop<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::sigmoid_backprop<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
......@@ -1116,28 +1124,28 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sign<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Sin:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Sinh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sinh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Slice:
{
const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::slice<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
......@@ -1148,8 +1156,8 @@ private:
case OP_TYPEID::Softmax:
{
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::softmax<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_output_shape(0),
softmax->get_axes());
break;
......@@ -1158,7 +1166,7 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sqrt<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::StopGradient: { throw unsupported_op("Unsupported op 'StopGradient'");
......@@ -1166,17 +1174,17 @@ private:
case OP_TYPEID::Subtract:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::subtract<T>(static_cast<const T*>(args[0]),
static_cast<const T*>(args[1]),
static_cast<T*>(out[0]),
reference::subtract<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Sum:
{
const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(static_cast<const T*>(args[0]),
static_cast<T*>(out[0]),
reference::sum<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
sum->get_reduction_axes());
......@@ -1186,14 +1194,14 @@ private:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::tan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Tanh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::tanh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count);
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::TopK:
......@@ -1201,9 +1209,9 @@ private:
const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (node.get_output_element_type(0) == element::i64)
{
reference::topk<T, int64_t>(static_cast<const T*>(args[0]),
static_cast<int64_t*>(out[0]),
static_cast<T*>(out[1]),
reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int64_t>(),
out[1]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
topk->get_top_k_axis(),
......@@ -1212,9 +1220,9 @@ private:
}
else if (node.get_output_element_type(0) == element::i32)
{
reference::topk<T, int32_t>(static_cast<const T*>(args[0]),
static_cast<int32_t*>(out[0]),
static_cast<T*>(out[1]),
reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int32_t>(),
out[1]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
topk->get_top_k_axis(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment