Commit 80b4b1da authored by Fabian Boemer, committed by Scott Cyphers

Added sqrt to constant folding (#2610)

* Added sqrt to constant folding

* Added negative sqrt checking
parent 8d45fc33
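For orientation, a minimal usage sketch (not part of this commit) of what the change enables: a Sqrt node fed by a Constant is replaced by a precomputed Constant when the ConstantFolding pass runs. The header paths, the Shape, and the values below are my assumptions for illustration; the op and pass calls are the ones exercised by the diff that follows.

// Hypothetical example, not code from this commit: fold Sqrt(Constant)
// at pass time with pass::ConstantFolding.
#include "ngraph/ngraph.hpp"                // assumed umbrella header
#include "ngraph/pass/constant_folding.hpp" // assumed pass header
#include "ngraph/pass/manager.hpp"          // assumed pass manager header

using namespace ngraph;

int main()
{
    // Constant holding perfect squares so the i32 result is exact.
    Shape shape{2, 2};
    auto d = std::make_shared<op::Constant>(element::i32, shape,
                                            std::vector<int>{1, 4, 9, 16});
    auto sqrt = std::make_shared<op::Sqrt>(d);
    auto f = std::make_shared<Function>(NodeVector{sqrt}, ParameterVector{});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConstantFolding>();
    pass_manager.run_passes(f);

    // After the pass, the graph should contain a Constant with {1, 2, 3, 4}
    // in place of the Sqrt node. A constant with negative entries would
    // instead make run_passes throw (see the negative-sqrt check below).
    return 0;
}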
@@ -32,6 +32,7 @@
 #include "ngraph/op/quantize.hpp"
 #include "ngraph/op/relu.hpp"
 #include "ngraph/op/reshape.hpp"
+#include "ngraph/op/sqrt.hpp"
 #include "ngraph/op/subtract.hpp"
 #include "ngraph/pattern/matcher.hpp"
 #include "ngraph/pattern/op/label.hpp"
@@ -48,6 +49,7 @@
 #include "ngraph/runtime/reference/quantize.hpp"
 #include "ngraph/runtime/reference/relu.hpp"
 #include "ngraph/runtime/reference/reshape.hpp"
+#include "ngraph/runtime/reference/sqrt.hpp"
 #include "ngraph/runtime/reference/subtract.hpp"

 using namespace std;
@@ -385,7 +387,7 @@ void pass::ConstantFolding::construct_constant_binary()
 bool is_supported_unary_op(std::shared_ptr<Node> n)
 {
     return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Negative>(n) ||
-           std::dynamic_pointer_cast<op::Relu>(n);
+           std::dynamic_pointer_cast<op::Relu>(n) || std::dynamic_pointer_cast<op::Sqrt>(n);
 }

 template <class T>
@@ -410,6 +412,16 @@ shared_ptr<op::Constant> make_constant_unary(shared_ptr<op::Constant> constant,
         runtime::reference::relu<T>(
             constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape));
     }
+    else if (std::dynamic_pointer_cast<op::Sqrt>(unary))
+    {
+        std::vector<T> values{constant->get_vector<T>()};
+        if (std::any_of(values.begin(), values.end(), [](T i) { return i < 0; }))
+        {
+            throw ngraph_error("Square root of negative value");
+        }
+        runtime::reference::sqrt<T>(
+            constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape));
+    }
     else
     {
         NGRAPH_ASSERT(false) << "must be consistent with is_supported_unary_op";
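The negative-input guard added above ("Added negative sqrt checking") can be illustrated in isolation. A standalone sketch using only the standard library (the checked_sqrt helper name is mine, not from the commit): reject any negative element before computing element-wise square roots, mirroring the new Sqrt branch in make_constant_unary.

// Standalone sketch of the guard, not code from this commit.
#include <algorithm>
#include <cmath>
#include <stdexcept>
#include <vector>

template <typename T>
std::vector<T> checked_sqrt(const std::vector<T>& values)
{
    // Same idea as the new Sqrt branch: any negative element aborts the fold.
    if (std::any_of(values.begin(), values.end(), [](T v) { return v < 0; }))
    {
        throw std::runtime_error("Square root of negative value");
    }
    std::vector<T> out(values.size());
    std::transform(values.begin(), values.end(), out.begin(),
                   [](T v) { return static_cast<T>(std::sqrt(v)); });
    return out;
}

// checked_sqrt(std::vector<int>{1, 4, 9, 16}) yields {1, 2, 3, 4};
// checked_sqrt(std::vector<int>{-1, -1, -1, -1}) throws, matching the
// ASSERT_ANY_THROW case added in the test changes below.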
@@ -178,9 +178,11 @@ TEST(constant_folding, constant_unary_binary)
     vector<int> values_a{1, 2, 3, 4};
     vector<int> values_b{1, 2, 3, 4};
     vector<int> values_c{-1, -1, -1, -1};
+    vector<int> values_d{1, 4, 9, 16};
     auto a = make_shared<op::Constant>(element::i32, shape_in, values_a);
     auto b = make_shared<op::Constant>(element::i32, shape_in, values_b);
     auto c = make_shared<op::Constant>(element::i32, shape_in, values_c);
+    auto d = make_shared<op::Constant>(element::i32, shape_in, values_d);

     auto add = a + b;
     auto sub = a - b;
@@ -190,9 +192,12 @@ TEST(constant_folding, constant_unary_binary)
     auto max = make_shared<op::Maximum>(a, c);
     auto absn = make_shared<op::Abs>(c);
     auto neg = make_shared<op::Negative>(c);
+    auto sqrt = make_shared<op::Sqrt>(d);
+    auto neg_sqrt = make_shared<op::Sqrt>(c);

-    auto f = make_shared<Function>(NodeVector{add, sub, mul, divn, min, max, absn, neg},
+    auto f = make_shared<Function>(NodeVector{add, sub, mul, divn, min, max, absn, neg, sqrt},
                                    ParameterVector{});
+    auto f_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});

     pass::Manager pass_manager;
     pass_manager.register_pass<pass::ConstantFolding>();
@@ -206,6 +211,7 @@ TEST(constant_folding, constant_unary_binary)
     vector<int> min_expected{-1, -1, -1, -1};
     vector<int> max_expected{1, 2, 3, 4};
     vector<int> abs_neg_expected{1, 1, 1, 1};
+    vector<int> sqrt_expected{1, 2, 3, 4};

     ASSERT_EQ(get_result_constant<int>(f, 0), add_expected);
     ASSERT_EQ(get_result_constant<int>(f, 1), sub_expected);
@@ -215,6 +221,8 @@ TEST(constant_folding, constant_unary_binary)
     ASSERT_EQ(get_result_constant<int>(f, 5), max_expected);
     ASSERT_EQ(get_result_constant<int>(f, 6), abs_neg_expected);
     ASSERT_EQ(get_result_constant<int>(f, 7), abs_neg_expected);
+    ASSERT_EQ(get_result_constant<int>(f, 8), sqrt_expected);
+    ASSERT_ANY_THROW(pass_manager.run_passes(f_error));
 }

 TEST(constant_folding, const_dequantize)