Commit 80b4b1da authored by Fabian Boemer's avatar Fabian Boemer Committed by Scott Cyphers

Added sqrt to constant folding (#2610)

* Added sqrt to constant folding

* Added negative sqrt checking
parent 8d45fc33
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
...@@ -48,6 +49,7 @@ ...@@ -48,6 +49,7 @@
#include "ngraph/runtime/reference/quantize.hpp" #include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/relu.hpp" #include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp" #include "ngraph/runtime/reference/subtract.hpp"
using namespace std; using namespace std;
...@@ -385,7 +387,7 @@ void pass::ConstantFolding::construct_constant_binary() ...@@ -385,7 +387,7 @@ void pass::ConstantFolding::construct_constant_binary()
bool is_supported_unary_op(std::shared_ptr<Node> n) bool is_supported_unary_op(std::shared_ptr<Node> n)
{ {
return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Negative>(n) || return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Negative>(n) ||
std::dynamic_pointer_cast<op::Relu>(n); std::dynamic_pointer_cast<op::Relu>(n) || std::dynamic_pointer_cast<op::Sqrt>(n);
} }
template <class T> template <class T>
...@@ -410,6 +412,16 @@ shared_ptr<op::Constant> make_constant_unary(shared_ptr<op::Constant> constant, ...@@ -410,6 +412,16 @@ shared_ptr<op::Constant> make_constant_unary(shared_ptr<op::Constant> constant,
runtime::reference::relu<T>( runtime::reference::relu<T>(
constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape)); constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape));
} }
else if (std::dynamic_pointer_cast<op::Sqrt>(unary))
{
std::vector<T> values{constant->get_vector<T>()};
if (std::any_of(values.begin(), values.end(), [](T i) { return i < 0; }))
{
throw ngraph_error("Square root of negative value");
}
runtime::reference::sqrt<T>(
constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape));
}
else else
{ {
NGRAPH_ASSERT(false) << "must be consistent with is_supported_unary_op"; NGRAPH_ASSERT(false) << "must be consistent with is_supported_unary_op";
......
...@@ -178,9 +178,11 @@ TEST(constant_folding, constant_unary_binary) ...@@ -178,9 +178,11 @@ TEST(constant_folding, constant_unary_binary)
vector<int> values_a{1, 2, 3, 4}; vector<int> values_a{1, 2, 3, 4};
vector<int> values_b{1, 2, 3, 4}; vector<int> values_b{1, 2, 3, 4};
vector<int> values_c{-1, -1, -1, -1}; vector<int> values_c{-1, -1, -1, -1};
vector<int> values_d{1, 4, 9, 16};
auto a = make_shared<op::Constant>(element::i32, shape_in, values_a); auto a = make_shared<op::Constant>(element::i32, shape_in, values_a);
auto b = make_shared<op::Constant>(element::i32, shape_in, values_b); auto b = make_shared<op::Constant>(element::i32, shape_in, values_b);
auto c = make_shared<op::Constant>(element::i32, shape_in, values_c); auto c = make_shared<op::Constant>(element::i32, shape_in, values_c);
auto d = make_shared<op::Constant>(element::i32, shape_in, values_d);
auto add = a + b; auto add = a + b;
auto sub = a - b; auto sub = a - b;
...@@ -190,9 +192,12 @@ TEST(constant_folding, constant_unary_binary) ...@@ -190,9 +192,12 @@ TEST(constant_folding, constant_unary_binary)
auto max = make_shared<op::Maximum>(a, c); auto max = make_shared<op::Maximum>(a, c);
auto absn = make_shared<op::Abs>(c); auto absn = make_shared<op::Abs>(c);
auto neg = make_shared<op::Negative>(c); auto neg = make_shared<op::Negative>(c);
auto sqrt = make_shared<op::Sqrt>(d);
auto neg_sqrt = make_shared<op::Sqrt>(c);
auto f = make_shared<Function>(NodeVector{add, sub, mul, divn, min, max, absn, neg}, auto f = make_shared<Function>(NodeVector{add, sub, mul, divn, min, max, absn, neg, sqrt},
ParameterVector{}); ParameterVector{});
auto f_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>(); pass_manager.register_pass<pass::ConstantFolding>();
...@@ -206,6 +211,7 @@ TEST(constant_folding, constant_unary_binary) ...@@ -206,6 +211,7 @@ TEST(constant_folding, constant_unary_binary)
vector<int> min_expected{-1, -1, -1, -1}; vector<int> min_expected{-1, -1, -1, -1};
vector<int> max_expected{1, 2, 3, 4}; vector<int> max_expected{1, 2, 3, 4};
vector<int> abs_neg_expected{1, 1, 1, 1}; vector<int> abs_neg_expected{1, 1, 1, 1};
vector<int> sqrt_expected{1, 2, 3, 4};
ASSERT_EQ(get_result_constant<int>(f, 0), add_expected); ASSERT_EQ(get_result_constant<int>(f, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(f, 1), sub_expected); ASSERT_EQ(get_result_constant<int>(f, 1), sub_expected);
...@@ -215,6 +221,8 @@ TEST(constant_folding, constant_unary_binary) ...@@ -215,6 +221,8 @@ TEST(constant_folding, constant_unary_binary)
ASSERT_EQ(get_result_constant<int>(f, 5), max_expected); ASSERT_EQ(get_result_constant<int>(f, 5), max_expected);
ASSERT_EQ(get_result_constant<int>(f, 6), abs_neg_expected); ASSERT_EQ(get_result_constant<int>(f, 6), abs_neg_expected);
ASSERT_EQ(get_result_constant<int>(f, 7), abs_neg_expected); ASSERT_EQ(get_result_constant<int>(f, 7), abs_neg_expected);
ASSERT_EQ(get_result_constant<int>(f, 8), sqrt_expected);
ASSERT_ANY_THROW(pass_manager.run_passes(f_error));
} }
TEST(constant_folding, const_dequantize) TEST(constant_folding, const_dequantize)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment