Unverified Commit 17bcff88 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into bob/unit-test

parents c5a7e690 f4d44bbc
......@@ -37,6 +37,7 @@
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
......@@ -59,6 +60,7 @@
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
......@@ -509,7 +511,8 @@ void pass::ConstantFolding::construct_constant_binary()
bool is_supported_unary_op(std::shared_ptr<Node> n)
{
return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Negative>(n) ||
std::dynamic_pointer_cast<op::Relu>(n) || std::dynamic_pointer_cast<op::Sqrt>(n);
std::dynamic_pointer_cast<op::Relu>(n) || std::dynamic_pointer_cast<op::Sign>(n) ||
std::dynamic_pointer_cast<op::Sqrt>(n);
}
template <class T>
......@@ -521,7 +524,7 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
if (std::dynamic_pointer_cast<op::Sqrt>(unary))
{
std::vector<T> values{constant->get_vector<T>()};
if (std::any_of(values.begin(), values.end(), [](T i) { return i < 0; }))
if (std::any_of(values.begin(), values.end(), [](T i) { return i < T(0); }))
{
throw ngraph_error("Square root of negative value");
}
......@@ -556,6 +559,11 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
runtime::reference::relu<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Sign>(unary))
{
runtime::reference::sign<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Sqrt>(unary))
{
runtime::reference::sqrt<T>(
......@@ -603,33 +611,59 @@ void pass::ConstantFolding::construct_constant_unary()
func = handler->second(unary_match.get());
}
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type();
if (type == element::i32)
{
replace_node(m.get_match_root(),
fold_constant_unary<int>(constant_match, unary_match, func));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
fold_constant_unary<int8_t>(constant_match, unary_match, func));
return true;
}
else if (type == element::f32)
switch (type.get_type_enum())
{
replace_node(m.get_match_root(),
fold_constant_unary<float>(constant_match, unary_match, func));
return true;
}
else if (type == element::f64)
{
replace_node(m.get_match_root(),
fold_constant_unary<double>(constant_match, unary_match, func));
return true;
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_unary_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_unary_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_unary<char>(constant_match, unary_match, func);
break;
case element::Type_t::bf16:
replacement = fold_constant_unary<bfloat16>(constant_match, unary_match, func);
break;
case element::Type_t::f16:
replacement = fold_constant_unary<float16>(constant_match, unary_match, func);
break;
case element::Type_t::f32:
replacement = fold_constant_unary<float>(constant_match, unary_match, func);
break;
case element::Type_t::f64:
replacement = fold_constant_unary<double>(constant_match, unary_match, func);
break;
case element::Type_t::i8:
replacement = fold_constant_unary<int8_t>(constant_match, unary_match, func);
break;
case element::Type_t::i16:
replacement = fold_constant_unary<int16_t>(constant_match, unary_match, func);
break;
case element::Type_t::i32:
replacement = fold_constant_unary<int32_t>(constant_match, unary_match, func);
break;
case element::Type_t::i64:
replacement = fold_constant_unary<int64_t>(constant_match, unary_match, func);
break;
case element::Type_t::u8:
replacement = fold_constant_unary<uint8_t>(constant_match, unary_match, func);
break;
case element::Type_t::u16:
replacement = fold_constant_unary<uint16_t>(constant_match, unary_match, func);
break;
case element::Type_t::u32:
replacement = fold_constant_unary<uint32_t>(constant_match, unary_match, func);
break;
case element::Type_t::u64:
replacement = fold_constant_unary<uint64_t>(constant_match, unary_match, func);
break;
}
return false;
replace_node(m.get_match_root(), replacement);
return true;
};
auto reshape_matcher = make_shared<pattern::Matcher>(uea, "ConstantFolding.ConstantUnary");
......
......@@ -470,6 +470,12 @@ namespace ngraph
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sqrt);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Sign)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sign);
}
#define TI(x) type_index(typeid(x))
BuildOpMap& GetGlobalBuildDispatcher()
......@@ -536,6 +542,7 @@ namespace ngraph
REGISTER_CF_BUILDER(Negative);
REGISTER_CF_BUILDER(Relu);
REGISTER_CF_BUILDER(Sqrt);
REGISTER_CF_BUILDER(Sign);
}
}
}
......@@ -30,7 +30,7 @@ namespace ngraph
for (size_t i = 0; i < count; i++)
{
// TODO: generic "abs" doesn't work here for some reason.
out[i] = (arg[i] < 0 ? -arg[i] : arg[i]);
out[i] = (arg[i] < T(0) ? T(-arg[i]) : arg[i]);
}
}
}
......
......@@ -29,7 +29,7 @@ namespace ngraph
{
for (size_t i = 0; i < count; i++)
{
out[i] = (arg[i] < 0 ? -1 : (arg[i] > 0 ? 1 : 0));
out[i] = (arg[i] < T(0) ? T(-1) : (arg[i] > T(0) ? T(1) : T(0)));
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment