Unverified Commit 5f0a9f96 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

add templated get_data_ptr() methods to HostTensorView and Constant to make…

add templated get_data_ptr() methods to HostTensorView and Constant to make using them a little cleaner. (#924)
parent 30d24597
...@@ -154,6 +154,12 @@ namespace ngraph ...@@ -154,6 +154,12 @@ namespace ngraph
} }
const void* get_data_ptr() const { return m_data; } const void* get_data_ptr() const { return m_data; }
template <typename T>
const T* get_data_ptr() const
{
return reinterpret_cast<T*>(m_data);
}
bool is_constant() const override { return true; } bool is_constant() const override { return true; }
protected: protected:
template <typename T> template <typename T>
......
...@@ -46,6 +46,18 @@ public: ...@@ -46,6 +46,18 @@ public:
char* get_data_ptr(); char* get_data_ptr();
const char* get_data_ptr() const; const char* get_data_ptr() const;
template <typename T>
T* get_data_ptr()
{
return reinterpret_cast<T*>(get_data_ptr());
}
template <typename T>
const T* get_data_ptr() const
{
return reinterpret_cast<T*>(get_data_ptr());
}
size_t get_size() const; size_t get_size() const;
const element::Type& get_element_type() const; const element::Type& get_element_type() const;
......
...@@ -298,7 +298,7 @@ void runtime::interpreter::INTBackend::perform_nan_check( ...@@ -298,7 +298,7 @@ void runtime::interpreter::INTBackend::perform_nan_check(
const element::Type& type = tv->get_tensor().get_element_type(); const element::Type& type = tv->get_tensor().get_element_type();
if (type == element::f32) if (type == element::f32)
{ {
const float* data = reinterpret_cast<float*>(tv->get_data_ptr()); const float* data = tv->get_data_ptr<float>();
for (size_t i = 0; i < tv->get_element_count(); i++) for (size_t i = 0; i < tv->get_element_count(); i++)
{ {
if (std::isnan(data[i])) if (std::isnan(data[i]))
...@@ -317,7 +317,7 @@ void runtime::interpreter::INTBackend::perform_nan_check( ...@@ -317,7 +317,7 @@ void runtime::interpreter::INTBackend::perform_nan_check(
} }
else if (type == element::f64) else if (type == element::f64)
{ {
const double* data = reinterpret_cast<double*>(tv->get_data_ptr()); const double* data = tv->get_data_ptr<double>();
for (size_t i = 0; i < tv->get_element_count(); i++) for (size_t i = 0; i < tv->get_element_count(); i++)
{ {
if (std::isnan(data[i])) if (std::isnan(data[i]))
......
...@@ -171,57 +171,53 @@ private: ...@@ -171,57 +171,53 @@ private:
std::string node_op = node.description(); std::string node_op = node.description();
if (node_op == "Abs") if (node_op == "Abs")
{ {
reference::abs<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::abs<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Acos") else if (node_op == "Acos")
{ {
reference::acos<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::acos<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Add") else if (node_op == "Add")
{ {
reference::add<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::add<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
#ifdef NGRAPH_DISTRIBUTED #ifdef NGRAPH_DISTRIBUTED
else if (node_op == "AllReduce") else if (node_op == "AllReduce")
{ {
reference::allreduce<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::allreduce<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_element_type(), args[0]->get_element_type(),
static_cast<int>(args[0]->get_element_count())); static_cast<int>(args[0]->get_element_count()));
} }
#endif #endif
else if (node_op == "And") else if (node_op == "And")
{ {
reference::logical_and(reinterpret_cast<char*>(args[0]->get_data_ptr()), reference::logical_and(args[0]->get_data_ptr<char>(),
reinterpret_cast<char*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<char>(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Asin") else if (node_op == "Asin")
{ {
reference::asin<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::asin<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Atan") else if (node_op == "Atan")
{ {
reference::atan<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::atan<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "AvgPool") else if (node_op == "AvgPool")
{ {
op::AvgPool* avg_pool = dynamic_cast<op::AvgPool*>(&node); op::AvgPool* avg_pool = dynamic_cast<op::AvgPool*>(&node);
reference::avg_pool<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::avg_pool<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
avg_pool->get_window_shape(), avg_pool->get_window_shape(),
...@@ -233,8 +229,8 @@ private: ...@@ -233,8 +229,8 @@ private:
else if (node_op == "AvgPoolBackprop") else if (node_op == "AvgPoolBackprop")
{ {
op::AvgPoolBackprop* apb = dynamic_cast<op::AvgPoolBackprop*>(&node); op::AvgPoolBackprop* apb = dynamic_cast<op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::avg_pool_backprop<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
apb->get_window_shape(), apb->get_window_shape(),
...@@ -249,17 +245,16 @@ private: ...@@ -249,17 +245,16 @@ private:
Shape in_shape = args[0]->get_shape(); Shape in_shape = args[0]->get_shape();
Shape out_shape = out[0]->get_shape(); Shape out_shape = out[0]->get_shape();
AxisSet broadcast_axes = broadcast->get_broadcast_axes(); AxisSet broadcast_axes = broadcast->get_broadcast_axes();
reference::broadcast<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::broadcast<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
in_shape, in_shape,
out_shape, out_shape,
broadcast_axes); broadcast_axes);
} }
else if (node_op == "Ceiling") else if (node_op == "Ceiling")
{ {
reference::ceiling<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::ceiling<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Concat") else if (node_op == "Concat")
{ {
...@@ -268,11 +263,11 @@ private: ...@@ -268,11 +263,11 @@ private:
std::vector<Shape> in_shapes; std::vector<Shape> in_shapes;
for (std::shared_ptr<HostTensorView> arg : args) for (std::shared_ptr<HostTensorView> arg : args)
{ {
in_args.push_back(reinterpret_cast<T*>(arg->get_data_ptr())); in_args.push_back(arg->get_data_ptr<T>());
in_shapes.push_back(arg->get_shape()); in_shapes.push_back(arg->get_shape());
} }
reference::concat<T>(in_args, reference::concat<T>(in_args,
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
in_shapes, in_shapes,
out[0]->get_shape(), out[0]->get_shape(),
concat->get_concatenation_axis()); concat->get_concatenation_axis());
...@@ -280,9 +275,8 @@ private: ...@@ -280,9 +275,8 @@ private:
else if (node_op == "Constant") else if (node_op == "Constant")
{ {
const op::Constant* c = static_cast<const op::Constant*>(&node); const op::Constant* c = static_cast<const op::Constant*>(&node);
reference::constant<T>(reinterpret_cast<const T*>(c->get_data_ptr()), reference::constant<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Convert") else if (node_op == "Convert")
{ {
...@@ -290,68 +284,68 @@ private: ...@@ -290,68 +284,68 @@ private:
element::Type type = node.get_element_type(); element::Type type = node.get_element_type();
if (type == element::boolean) if (type == element::boolean)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (type == element::f32) else if (type == element::f32)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<float*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<float>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (type == element::f64) else if (type == element::f64)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<double*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<double>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (type == element::i8) else if (type == element::i8)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<int8_t*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<int8_t>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (type == element::i16) else if (type == element::i16)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<int16_t*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<int16_t>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (type == element::i32) else if (type == element::i32)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<int32_t*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<int32_t>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (type == element::i64) else if (type == element::i64)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<int64_t*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<int64_t>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (type == element::u8) else if (type == element::u8)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<uint8_t*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<uint8_t>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (type == element::u16) else if (type == element::u16)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<uint16_t*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<uint16_t>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (type == element::u32) else if (type == element::u32)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<uint32_t*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<uint32_t>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (type == element::u64) else if (type == element::u64)
{ {
reference::convert<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convert<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<uint64_t*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<uint64_t>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else else
...@@ -364,9 +358,9 @@ private: ...@@ -364,9 +358,9 @@ private:
else if (node_op == "Convolution") else if (node_op == "Convolution")
{ {
auto c = static_cast<const op::Convolution*>(&node); auto c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convolution<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
args[1]->get_shape(), args[1]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
...@@ -386,9 +380,9 @@ private: ...@@ -386,9 +380,9 @@ private:
else if (node_op == "ConvolutionBackpropFilters") else if (node_op == "ConvolutionBackpropFilters")
{ {
auto c = static_cast<const op::ConvolutionBackpropFilters*>(&node); auto c = static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::convolution<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
args[1]->get_shape(), args[1]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
...@@ -409,9 +403,9 @@ private: ...@@ -409,9 +403,9 @@ private:
{ {
// Note that args[1] and args[0] are switched here from the usual order. // Note that args[1] and args[0] are switched here from the usual order.
auto c = static_cast<const op::ConvolutionBackpropData*>(&node); auto c = static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(reinterpret_cast<T*>(args[1]->get_data_ptr()), reference::convolution<T>(args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[0]->get_data_ptr()), args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[1]->get_shape(), args[1]->get_shape(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
...@@ -430,30 +424,28 @@ private: ...@@ -430,30 +424,28 @@ private:
} }
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
reference::cos<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::cos<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Cosh") else if (node_op == "Cosh")
{ {
reference::cosh<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::cosh<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Divide") else if (node_op == "Divide")
{ {
reference::divide<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::divide<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Dot") else if (node_op == "Dot")
{ {
op::Dot* dot = dynamic_cast<op::Dot*>(&node); op::Dot* dot = dynamic_cast<op::Dot*>(&node);
reference::dot(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::dot(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
args[1]->get_shape(), args[1]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
...@@ -462,22 +454,20 @@ private: ...@@ -462,22 +454,20 @@ private:
else if (node_op == "Equal") else if (node_op == "Equal")
{ {
reference::equal<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::equal<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Exp") else if (node_op == "Exp")
{ {
reference::exp<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::exp<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Floor") else if (node_op == "Floor")
{ {
reference::floor<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::floor<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "FunctionCall") else if (node_op == "FunctionCall")
{ {
...@@ -499,60 +489,59 @@ private: ...@@ -499,60 +489,59 @@ private:
} }
else if (node_op == "Greater") else if (node_op == "Greater")
{ {
reference::greater<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::greater<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "GreaterEq") else if (node_op == "GreaterEq")
{ {
reference::greater_eq<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::greater_eq<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Less") else if (node_op == "Less")
{ {
reference::less<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::less<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "LessEq") else if (node_op == "LessEq")
{ {
reference::less_eq<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::less_eq<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Log") else if (node_op == "Log")
{ {
reference::log<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::log<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Max") else if (node_op == "Max")
{ {
const op::Max* max = static_cast<const op::Max*>(&node); const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::max<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
max->get_reduction_axes()); max->get_reduction_axes());
} }
else if (node_op == "Maximum") else if (node_op == "Maximum")
{ {
reference::maximum<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::maximum<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "MaxPool") else if (node_op == "MaxPool")
{ {
op::MaxPool* max_pool = dynamic_cast<op::MaxPool*>(&node); op::MaxPool* max_pool = dynamic_cast<op::MaxPool*>(&node);
reference::max_pool<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::max_pool<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
max_pool->get_window_shape(), max_pool->get_window_shape(),
...@@ -564,9 +553,9 @@ private: ...@@ -564,9 +553,9 @@ private:
{ {
op::MaxPoolBackprop* max_pool_backprop = dynamic_cast<op::MaxPoolBackprop*>(&node); op::MaxPoolBackprop* max_pool_backprop = dynamic_cast<op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::max_pool_backprop<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[1]->get_shape(), args[1]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
max_pool_backprop->get_window_shape(), max_pool_backprop->get_window_shape(),
...@@ -577,59 +566,58 @@ private: ...@@ -577,59 +566,58 @@ private:
else if (node_op == "Min") else if (node_op == "Min")
{ {
const op::Min* min = static_cast<const op::Min*>(&node); const op::Min* min = static_cast<const op::Min*>(&node);
reference::min<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::min<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
min->get_reduction_axes()); min->get_reduction_axes());
} }
else if (node_op == "Minimum") else if (node_op == "Minimum")
{ {
reference::minimum<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::minimum<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Multiply") else if (node_op == "Multiply")
{ {
reference::multiply<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::multiply<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Negative") else if (node_op == "Negative")
{ {
reference::negate<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::negate<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Not") else if (node_op == "Not")
{ {
reference::logical_not(reinterpret_cast<char*>(args[0]->get_data_ptr()), reference::logical_not(args[0]->get_data_ptr<char>(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "NotEqual") else if (node_op == "NotEqual")
{ {
reference::not_equal<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::not_equal<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "OneHot") else if (node_op == "OneHot")
{ {
auto oh = static_cast<const op::OneHot*>(&node); auto oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::one_hot<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
oh->get_one_hot_axis()); oh->get_one_hot_axis());
} }
else if (node_op == "Or") else if (node_op == "Or")
{ {
reference::logical_or(reinterpret_cast<char*>(args[0]->get_data_ptr()), reference::logical_or(args[0]->get_data_ptr<char>(),
reinterpret_cast<char*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<char>(),
reinterpret_cast<char*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<char>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Parameter") else if (node_op == "Parameter")
...@@ -639,9 +627,9 @@ private: ...@@ -639,9 +627,9 @@ private:
{ {
op::Pad* pad = dynamic_cast<op::Pad*>(&node); op::Pad* pad = dynamic_cast<op::Pad*>(&node);
reference::pad(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::pad(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
node.get_inputs().at(0).get_shape(), node.get_inputs().at(0).get_shape(),
node.get_output_shape(0), node.get_output_shape(0),
pad->get_padding_below(), pad->get_padding_below(),
...@@ -650,16 +638,16 @@ private: ...@@ -650,16 +638,16 @@ private:
} }
else if (node_op == "Power") else if (node_op == "Power")
{ {
reference::power<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::power<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Product") else if (node_op == "Product")
{ {
const op::Product* product = static_cast<const op::Product*>(&node); const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::product<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
product->get_reduction_axes()); product->get_reduction_axes());
...@@ -676,15 +664,15 @@ private: ...@@ -676,15 +664,15 @@ private:
node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_temp_y"); node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_temp_y");
auto tr = std::make_shared<HostTensorView>( auto tr = std::make_shared<HostTensorView>(
node.get_output_element_type(0), Shape{}, "reduce_temp_r"); node.get_output_element_type(0), Shape{}, "reduce_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x; *(tx->get_data_ptr<T>()) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y; *(ty->get_data_ptr<T>()) = y;
call(reduction_function, {tr}, {tx, ty}); call(reduction_function, {tr}, {tx, ty});
return *(reinterpret_cast<T*>(tr->get_data_ptr())); return *(tr->get_data_ptr<T>());
}; };
reference::reduce(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::reduce(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
node.get_inputs().at(0).get_shape(), node.get_inputs().at(0).get_shape(),
node.get_output_shape(0), node.get_output_shape(0),
reduce->get_reduction_axes(), reduce->get_reduction_axes(),
...@@ -702,15 +690,15 @@ private: ...@@ -702,15 +690,15 @@ private:
node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_window_temp_y"); node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_window_temp_y");
auto tr = std::make_shared<HostTensorView>( auto tr = std::make_shared<HostTensorView>(
node.get_output_element_type(0), Shape{}, "reduce_window_temp_r"); node.get_output_element_type(0), Shape{}, "reduce_window_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x; *(tx->get_data_ptr<T>()) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y; *(ty->get_data_ptr<T>()) = y;
call(reduction_function, {tr}, {tx, ty}); call(reduction_function, {tr}, {tx, ty});
return *(reinterpret_cast<T*>(tr->get_data_ptr())); return *(tr->get_data_ptr<T>());
}; };
reference::reduce_window(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::reduce_window(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
node.get_inputs().at(0).get_shape(), node.get_inputs().at(0).get_shape(),
node.get_output_shape(0), node.get_output_shape(0),
f, f,
...@@ -719,23 +707,22 @@ private: ...@@ -719,23 +707,22 @@ private:
} }
else if (node_op == "Relu") else if (node_op == "Relu")
{ {
reference::relu<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::relu<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "ReluBackprop") else if (node_op == "ReluBackprop")
{ {
reference::relu_backprop<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::relu_backprop<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "ReplaceSlice") else if (node_op == "ReplaceSlice")
{ {
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node); const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::replace_slice<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[1]->get_shape(), args[1]->get_shape(),
slice->get_lower_bounds(), slice->get_lower_bounds(),
slice->get_upper_bounds(), slice->get_upper_bounds(),
...@@ -745,8 +732,8 @@ private: ...@@ -745,8 +732,8 @@ private:
else if (node_op == "Reshape") else if (node_op == "Reshape")
{ {
op::Reshape* reshape = dynamic_cast<op::Reshape*>(&node); op::Reshape* reshape = dynamic_cast<op::Reshape*>(&node);
reference::reshape(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::reshape(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
reshape->get_input_order(), reshape->get_input_order(),
out[0]->get_shape()); out[0]->get_shape());
...@@ -754,25 +741,25 @@ private: ...@@ -754,25 +741,25 @@ private:
else if (node_op == "Result") else if (node_op == "Result")
{ {
op::Result* res = dynamic_cast<op::Result*>(&node); op::Result* res = dynamic_cast<op::Result*>(&node);
reference::result(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::result(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
shape_size(res->get_shape())); shape_size(res->get_shape()));
} }
else if (node_op == "Reverse") else if (node_op == "Reverse")
{ {
op::Reverse* reverse = dynamic_cast<op::Reverse*>(&node); op::Reverse* reverse = dynamic_cast<op::Reverse*>(&node);
reference::reverse(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::reverse(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
reverse->get_reversed_axes()); reverse->get_reversed_axes());
} }
else if (node_op == "Select") else if (node_op == "Select")
{ {
reference::select<T>(reinterpret_cast<char*>(args[0]->get_data_ptr()), reference::select<T>(args[0]->get_data_ptr<char>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[2]->get_data_ptr()), args[2]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "SelectAndScatter") else if (node_op == "SelectAndScatter")
...@@ -790,10 +777,10 @@ private: ...@@ -790,10 +777,10 @@ private:
node.get_inputs().at(1).get_element_type(), Shape{}, "selection_temp_y"); node.get_inputs().at(1).get_element_type(), Shape{}, "selection_temp_y");
auto tr = std::make_shared<runtime::HostTensorView>( auto tr = std::make_shared<runtime::HostTensorView>(
element::boolean, Shape{}, "selection_temp_r"); element::boolean, Shape{}, "selection_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x; *(tx->get_data_ptr<T>()) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y; *(ty->get_data_ptr<T>()) = y;
call(selection_function, {tr}, {tx, ty}); call(selection_function, {tr}, {tx, ty});
return *(reinterpret_cast<char*>(tr->get_data_ptr())); return *(tr->get_data_ptr<char>());
}; };
std::shared_ptr<ngraph::Function> scatter_function = std::shared_ptr<ngraph::Function> scatter_function =
...@@ -805,16 +792,16 @@ private: ...@@ -805,16 +792,16 @@ private:
node.get_inputs().at(1).get_element_type(), Shape{}, "scatter_temp_y"); node.get_inputs().at(1).get_element_type(), Shape{}, "scatter_temp_y");
auto tr = std::make_shared<runtime::HostTensorView>( auto tr = std::make_shared<runtime::HostTensorView>(
node.get_output_element_type(0), Shape{}, "scatter_temp_r"); node.get_output_element_type(0), Shape{}, "scatter_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x; *(tx->get_data_ptr<T>()) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y; *(ty->get_data_ptr<T>()) = y;
call(scatter_function, {tr}, {tx, ty}); call(scatter_function, {tr}, {tx, ty});
return *(reinterpret_cast<T*>(tr->get_data_ptr())); return *(tr->get_data_ptr<T>());
}; };
reference::select_and_scatter<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::select_and_scatter<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[2]->get_data_ptr()), args[2]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
args[1]->get_shape(), args[1]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
...@@ -825,27 +812,24 @@ private: ...@@ -825,27 +812,24 @@ private:
} }
else if (node_op == "Sign") else if (node_op == "Sign")
{ {
reference::sign<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::sign<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Sin") else if (node_op == "Sin")
{ {
reference::sin<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::sin<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Sinh") else if (node_op == "Sinh")
{ {
reference::sinh<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::sinh<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Slice") else if (node_op == "Slice")
{ {
const op::Slice* slice = static_cast<const op::Slice*>(&node); const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::slice<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
slice->get_lower_bounds(), slice->get_lower_bounds(),
slice->get_upper_bounds(), slice->get_upper_bounds(),
...@@ -855,44 +839,41 @@ private: ...@@ -855,44 +839,41 @@ private:
else if (node_op == "Softmax") else if (node_op == "Softmax")
{ {
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node); const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::softmax<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
out[0]->get_shape(), out[0]->get_shape(),
softmax->get_axes()); softmax->get_axes());
} }
else if (node_op == "Sqrt") else if (node_op == "Sqrt")
{ {
reference::sqrt<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::sqrt<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Subtract") else if (node_op == "Subtract")
{ {
reference::subtract<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::subtract<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(args[1]->get_data_ptr()), args[1]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Sum") else if (node_op == "Sum")
{ {
const op::Sum* sum = static_cast<const op::Sum*>(&node); const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::sum<T>(args[0]->get_data_ptr<T>(),
reinterpret_cast<T*>(out[0]->get_data_ptr()), out[0]->get_data_ptr<T>(),
args[0]->get_shape(), args[0]->get_shape(),
out[0]->get_shape(), out[0]->get_shape(),
sum->get_reduction_axes()); sum->get_reduction_axes());
} }
else if (node_op == "Tan") else if (node_op == "Tan")
{ {
reference::tan<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::tan<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "Tanh") else if (node_op == "Tanh")
{ {
reference::tanh<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), reference::tanh<T>(
reinterpret_cast<T*>(out[0]->get_data_ptr()), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else else
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment