Commit a7c5eb01 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

make logical ops input type aware (#1203)

parent f1ebcd3e
......@@ -202,9 +202,9 @@ private:
#endif
else if (node_op == "And")
{
reference::logical_and(args[0]->get_data_ptr<char>(),
args[1]->get_data_ptr<char>(),
out[0]->get_data_ptr<char>(),
reference::logical_and(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
}
else if (node_op == "Asin")
......@@ -633,9 +633,8 @@ private:
}
else if (node_op == "Not")
{
reference::logical_not(args[0]->get_data_ptr<char>(),
out[0]->get_data_ptr<char>(),
out[0]->get_element_count());
reference::logical_not(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
}
else if (node_op == "NotEqual")
{
......@@ -655,9 +654,9 @@ private:
}
else if (node_op == "Or")
{
reference::logical_or(args[0]->get_data_ptr<char>(),
args[1]->get_data_ptr<char>(),
out[0]->get_data_ptr<char>(),
reference::logical_or(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
out[0]->get_element_count());
}
else if (node_op == "Parameter")
......
......@@ -24,12 +24,12 @@ namespace ngraph
{
namespace reference
{
static inline void
logical_and(const char* arg0, const char* arg1, char* out, size_t count)
template <typename T>
void logical_and(const T* arg0, const T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] && arg1[i];
out[i] = static_cast<T>(arg0[i] && arg1[i]);
}
}
}
......
......@@ -24,14 +24,12 @@ namespace ngraph
{
namespace reference
{
static inline void
logical_not(const char* arg,
char* out,
size_t count) // TODO: using char for bool, is this right?
template <typename T>
void logical_not(const T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = !(arg[i]);
out[i] = static_cast<T>(!(arg[i]));
}
}
}
......
......@@ -24,12 +24,12 @@ namespace ngraph
{
namespace reference
{
static inline void
logical_or(const char* arg0, const char* arg1, char* out, size_t count)
template <typename T>
void logical_or(const T* arg0, const T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] || arg1[i];
out[i] = static_cast<T>(arg0[i] || arg1[i]);
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment