Unverified Commit 17bcff88 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into bob/unit-test

parents c5a7e690 f4d44bbc
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sqrt.hpp" #include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
...@@ -59,6 +60,7 @@ ...@@ -59,6 +60,7 @@
#include "ngraph/runtime/reference/relu.hpp" #include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/reverse.hpp" #include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sqrt.hpp" #include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp" #include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp" #include "ngraph/runtime/reference/sum.hpp"
...@@ -509,7 +511,8 @@ void pass::ConstantFolding::construct_constant_binary() ...@@ -509,7 +511,8 @@ void pass::ConstantFolding::construct_constant_binary()
bool is_supported_unary_op(std::shared_ptr<Node> n) bool is_supported_unary_op(std::shared_ptr<Node> n)
{ {
return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Negative>(n) || return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Negative>(n) ||
std::dynamic_pointer_cast<op::Relu>(n) || std::dynamic_pointer_cast<op::Sqrt>(n); std::dynamic_pointer_cast<op::Relu>(n) || std::dynamic_pointer_cast<op::Sign>(n) ||
std::dynamic_pointer_cast<op::Sqrt>(n);
} }
template <class T> template <class T>
...@@ -521,7 +524,7 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant, ...@@ -521,7 +524,7 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
if (std::dynamic_pointer_cast<op::Sqrt>(unary)) if (std::dynamic_pointer_cast<op::Sqrt>(unary))
{ {
std::vector<T> values{constant->get_vector<T>()}; std::vector<T> values{constant->get_vector<T>()};
if (std::any_of(values.begin(), values.end(), [](T i) { return i < 0; })) if (std::any_of(values.begin(), values.end(), [](T i) { return i < T(0); }))
{ {
throw ngraph_error("Square root of negative value"); throw ngraph_error("Square root of negative value");
} }
...@@ -556,6 +559,11 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant, ...@@ -556,6 +559,11 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
runtime::reference::relu<T>( runtime::reference::relu<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
} }
else if (std::dynamic_pointer_cast<op::Sign>(unary))
{
runtime::reference::sign<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Sqrt>(unary)) else if (std::dynamic_pointer_cast<op::Sqrt>(unary))
{ {
runtime::reference::sqrt<T>( runtime::reference::sqrt<T>(
...@@ -603,33 +611,59 @@ void pass::ConstantFolding::construct_constant_unary() ...@@ -603,33 +611,59 @@ void pass::ConstantFolding::construct_constant_unary()
func = handler->second(unary_match.get()); func = handler->second(unary_match.get());
} }
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type(); auto type = constant_match->get_element_type();
if (type == element::i32) switch (type.get_type_enum())
{
replace_node(m.get_match_root(),
fold_constant_unary<int>(constant_match, unary_match, func));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
fold_constant_unary<int8_t>(constant_match, unary_match, func));
return true;
}
else if (type == element::f32)
{ {
replace_node(m.get_match_root(), case element::Type_t::undefined:
fold_constant_unary<float>(constant_match, unary_match, func)); NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_unary_callback");
return true; break;
} case element::Type_t::dynamic:
else if (type == element::f64) NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_unary_callback");
{ break;
replace_node(m.get_match_root(), case element::Type_t::boolean:
fold_constant_unary<double>(constant_match, unary_match, func)); replacement = fold_constant_unary<char>(constant_match, unary_match, func);
return true; break;
case element::Type_t::bf16:
replacement = fold_constant_unary<bfloat16>(constant_match, unary_match, func);
break;
case element::Type_t::f16:
replacement = fold_constant_unary<float16>(constant_match, unary_match, func);
break;
case element::Type_t::f32:
replacement = fold_constant_unary<float>(constant_match, unary_match, func);
break;
case element::Type_t::f64:
replacement = fold_constant_unary<double>(constant_match, unary_match, func);
break;
case element::Type_t::i8:
replacement = fold_constant_unary<int8_t>(constant_match, unary_match, func);
break;
case element::Type_t::i16:
replacement = fold_constant_unary<int16_t>(constant_match, unary_match, func);
break;
case element::Type_t::i32:
replacement = fold_constant_unary<int32_t>(constant_match, unary_match, func);
break;
case element::Type_t::i64:
replacement = fold_constant_unary<int64_t>(constant_match, unary_match, func);
break;
case element::Type_t::u8:
replacement = fold_constant_unary<uint8_t>(constant_match, unary_match, func);
break;
case element::Type_t::u16:
replacement = fold_constant_unary<uint16_t>(constant_match, unary_match, func);
break;
case element::Type_t::u32:
replacement = fold_constant_unary<uint32_t>(constant_match, unary_match, func);
break;
case element::Type_t::u64:
replacement = fold_constant_unary<uint64_t>(constant_match, unary_match, func);
break;
} }
return false; replace_node(m.get_match_root(), replacement);
return true;
}; };
auto reshape_matcher = make_shared<pattern::Matcher>(uea, "ConstantFolding.ConstantUnary"); auto reshape_matcher = make_shared<pattern::Matcher>(uea, "ConstantFolding.ConstantUnary");
......
...@@ -470,6 +470,12 @@ namespace ngraph ...@@ -470,6 +470,12 @@ namespace ngraph
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sqrt); BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sqrt);
} }
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Sign)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sign);
}
#define TI(x) type_index(typeid(x)) #define TI(x) type_index(typeid(x))
BuildOpMap& GetGlobalBuildDispatcher() BuildOpMap& GetGlobalBuildDispatcher()
...@@ -536,6 +542,7 @@ namespace ngraph ...@@ -536,6 +542,7 @@ namespace ngraph
REGISTER_CF_BUILDER(Negative); REGISTER_CF_BUILDER(Negative);
REGISTER_CF_BUILDER(Relu); REGISTER_CF_BUILDER(Relu);
REGISTER_CF_BUILDER(Sqrt); REGISTER_CF_BUILDER(Sqrt);
REGISTER_CF_BUILDER(Sign);
} }
} }
} }
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
for (size_t i = 0; i < count; i++) for (size_t i = 0; i < count; i++)
{ {
// TODO: generic "abs" doesn't work here for some reason. // TODO: generic "abs" doesn't work here for some reason.
out[i] = (arg[i] < 0 ? -arg[i] : arg[i]); out[i] = (arg[i] < T(0) ? T(-arg[i]) : arg[i]);
} }
} }
} }
......
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
{ {
for (size_t i = 0; i < count; i++) for (size_t i = 0; i < count; i++)
{ {
out[i] = (arg[i] < 0 ? -1 : (arg[i] > 0 ? 1 : 0)); out[i] = (arg[i] < T(0) ? T(-1) : (arg[i] > T(0) ? T(1) : T(0)));
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment