Commit a7c5eb01 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

make logical ops input type aware (#1203)

parent f1ebcd3e
...@@ -202,9 +202,9 @@ private: ...@@ -202,9 +202,9 @@ private:
#endif #endif
else if (node_op == "And") else if (node_op == "And")
{ {
reference::logical_and(args[0]->get_data_ptr<char>(), reference::logical_and(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<char>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Asin") else if (node_op == "Asin")
...@@ -633,9 +633,8 @@ private: ...@@ -633,9 +633,8 @@ private:
} }
else if (node_op == "Not") else if (node_op == "Not")
{ {
reference::logical_not(args[0]->get_data_ptr<char>(), reference::logical_not(
out[0]->get_data_ptr<char>(), args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
out[0]->get_element_count());
} }
else if (node_op == "NotEqual") else if (node_op == "NotEqual")
{ {
...@@ -655,9 +654,9 @@ private: ...@@ -655,9 +654,9 @@ private:
} }
else if (node_op == "Or") else if (node_op == "Or")
{ {
reference::logical_or(args[0]->get_data_ptr<char>(), reference::logical_or(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<char>(), args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<T>(),
out[0]->get_element_count()); out[0]->get_element_count());
} }
else if (node_op == "Parameter") else if (node_op == "Parameter")
......
...@@ -24,12 +24,12 @@ namespace ngraph ...@@ -24,12 +24,12 @@ namespace ngraph
{ {
namespace reference namespace reference
{ {
static inline void template <typename T>
logical_and(const char* arg0, const char* arg1, char* out, size_t count) void logical_and(const T* arg0, const T* arg1, T* out, size_t count)
{ {
for (size_t i = 0; i < count; i++) for (size_t i = 0; i < count; i++)
{ {
out[i] = arg0[i] && arg1[i]; out[i] = static_cast<T>(arg0[i] && arg1[i]);
} }
} }
} }
......
...@@ -24,14 +24,12 @@ namespace ngraph ...@@ -24,14 +24,12 @@ namespace ngraph
{ {
namespace reference namespace reference
{ {
static inline void template <typename T>
logical_not(const char* arg, void logical_not(const T* arg, T* out, size_t count)
char* out,
size_t count) // TODO: using char for bool, is this right?
{ {
for (size_t i = 0; i < count; i++) for (size_t i = 0; i < count; i++)
{ {
out[i] = !(arg[i]); out[i] = static_cast<T>(!(arg[i]));
} }
} }
} }
......
...@@ -24,12 +24,12 @@ namespace ngraph ...@@ -24,12 +24,12 @@ namespace ngraph
{ {
namespace reference namespace reference
{ {
static inline void template <typename T>
logical_or(const char* arg0, const char* arg1, char* out, size_t count) void logical_or(const T* arg0, const T* arg1, T* out, size_t count)
{ {
for (size_t i = 0; i < count; i++) for (size_t i = 0; i < count; i++)
{ {
out[i] = arg0[i] || arg1[i]; out[i] = static_cast<T>(arg0[i] || arg1[i]);
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment