Unverified Commit 5f0a9f96 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

add templated get_data_ptr() methods to HostTensorView and Constant to make…

add templated get_data_ptr() methods to HostTensorView and Constant to make using them a little cleaner. (#924)
parent 30d24597
......@@ -154,6 +154,12 @@ namespace ngraph
}
const void* get_data_ptr() const { return m_data; }
template <typename T>
const T* get_data_ptr() const
{
return reinterpret_cast<T*>(m_data);
}
bool is_constant() const override { return true; }
protected:
template <typename T>
......
......@@ -46,6 +46,18 @@ public:
char* get_data_ptr();
const char* get_data_ptr() const;
template <typename T>
T* get_data_ptr()
{
return reinterpret_cast<T*>(get_data_ptr());
}
template <typename T>
const T* get_data_ptr() const
{
return reinterpret_cast<T*>(get_data_ptr());
}
size_t get_size() const;
const element::Type& get_element_type() const;
......
......@@ -298,7 +298,7 @@ void runtime::interpreter::INTBackend::perform_nan_check(
const element::Type& type = tv->get_tensor().get_element_type();
if (type == element::f32)
{
const float* data = reinterpret_cast<float*>(tv->get_data_ptr());
const float* data = tv->get_data_ptr<float>();
for (size_t i = 0; i < tv->get_element_count(); i++)
{
if (std::isnan(data[i]))
......@@ -317,7 +317,7 @@ void runtime::interpreter::INTBackend::perform_nan_check(
}
else if (type == element::f64)
{
const double* data = reinterpret_cast<double*>(tv->get_data_ptr());
const double* data = tv->get_data_ptr<double>();
for (size_t i = 0; i < tv->get_element_count(); i++)
{
if (std::isnan(data[i]))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment