Unverified Commit 5f0a9f96 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

add templated get_data_ptr() methods to HostTensorView and Constant to make…

add templated get_data_ptr() methods to HostTensorView and Constant to make using them a little cleaner. (#924)
parent 30d24597
...@@ -154,6 +154,12 @@ namespace ngraph ...@@ -154,6 +154,12 @@ namespace ngraph
} }
const void* get_data_ptr() const { return m_data; } const void* get_data_ptr() const { return m_data; }
template <typename T>
const T* get_data_ptr() const
{
return reinterpret_cast<T*>(m_data);
}
bool is_constant() const override { return true; } bool is_constant() const override { return true; }
protected: protected:
template <typename T> template <typename T>
......
...@@ -46,6 +46,18 @@ public: ...@@ -46,6 +46,18 @@ public:
char* get_data_ptr(); char* get_data_ptr();
const char* get_data_ptr() const; const char* get_data_ptr() const;
template <typename T>
T* get_data_ptr()
{
return reinterpret_cast<T*>(get_data_ptr());
}
template <typename T>
const T* get_data_ptr() const
{
return reinterpret_cast<T*>(get_data_ptr());
}
size_t get_size() const; size_t get_size() const;
const element::Type& get_element_type() const; const element::Type& get_element_type() const;
......
...@@ -298,7 +298,7 @@ void runtime::interpreter::INTBackend::perform_nan_check( ...@@ -298,7 +298,7 @@ void runtime::interpreter::INTBackend::perform_nan_check(
const element::Type& type = tv->get_tensor().get_element_type(); const element::Type& type = tv->get_tensor().get_element_type();
if (type == element::f32) if (type == element::f32)
{ {
const float* data = reinterpret_cast<float*>(tv->get_data_ptr()); const float* data = tv->get_data_ptr<float>();
for (size_t i = 0; i < tv->get_element_count(); i++) for (size_t i = 0; i < tv->get_element_count(); i++)
{ {
if (std::isnan(data[i])) if (std::isnan(data[i]))
...@@ -317,7 +317,7 @@ void runtime::interpreter::INTBackend::perform_nan_check( ...@@ -317,7 +317,7 @@ void runtime::interpreter::INTBackend::perform_nan_check(
} }
else if (type == element::f64) else if (type == element::f64)
{ {
const double* data = reinterpret_cast<double*>(tv->get_data_ptr()); const double* data = tv->get_data_ptr<double>();
for (size_t i = 0; i < tv->get_element_count(); i++) for (size_t i = 0; i < tv->get_element_count(); i++)
{ {
if (std::isnan(data[i])) if (std::isnan(data[i]))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment