Commit 3f017a1e authored by Amy Zhuang, committed by Scott Cyphers

Use cpu kernel for constant folding. (#2538)

* Use cpu kernel for constant folding.

* Add default empty map.

* Fix a bug.

* Add new files.

* Address PR feedback.

* Check constant folding map before checking type for unary and binary ops.

* Address PR feedback.

* Address PR feedback.

* Use all_close_f.

Add relu unit test.

Make changes for sqrt and pad.

* Fix a bug.
parent 9fea22b2
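
This change lets a backend hand the ConstantFolding pass a map of kernel builders keyed by node type (BuildNodeExecutorMap). When the map has an entry for a matched node, folding runs the backend's compiled kernel on the constant's raw buffers; with the default empty map it falls back to the ngraph::runtime::reference implementations, as before. As a rough, illustrative sketch of the new usage (mirroring the unit tests added below; the helper name fold_with_cpu_kernels is not part of the commit):

#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"

// Illustrative only: run constant folding over a Function using the CPU
// constant-folding kernels registered in GetGlobalCFDispatcherCPU().
void fold_with_cpu_kernels(std::shared_ptr<ngraph::Function> f)
{
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::ConstantFolding>(
        ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
    pass_manager.run_passes(f);
}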
@@ -51,42 +51,71 @@
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
+#include "ngraph/util.hpp"

using namespace std;
using namespace ngraph;

template <class T>
-shared_ptr<op::Constant> make_constant_reshape(shared_ptr<op::Constant> constant,
-                                               shared_ptr<op::Reshape> reshape)
+shared_ptr<op::Constant> fold_constant_reshape(shared_ptr<op::Constant> constant,
+                                               shared_ptr<op::Reshape> reshape,
+                                               NodeExecutorTy func)
{
    auto out_shape = reshape->get_shape();
    vector<T> out_vec(shape_size(out_shape));

-    runtime::reference::reshape<T>(constant->get_vector<T>().data(),
-                                   out_vec.data(),
-                                   constant->get_shape(),
-                                   reshape->get_input_order(),
-                                   out_shape);
+    if (func != nullptr)
+    {
+        vector<void*> inputs;
+        inputs.push_back(const_cast<void*>(constant->get_data_ptr()));
+        vector<void*> outputs;
+        outputs.push_back(out_vec.data());
+        func(inputs, outputs);
+    }
+    else
+    {
+        runtime::reference::reshape<T>(constant->get_data_ptr<T>(),
+                                       out_vec.data(),
+                                       constant->get_shape(),
+                                       reshape->get_input_order(),
+                                       out_shape);
+    }

    return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}

template <class T>
-shared_ptr<op::Constant> make_constant_pad(shared_ptr<op::Constant> constant,
-                                           shared_ptr<op::Pad> pad)
+shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
+                                           shared_ptr<op::Pad> pad,
+                                           NodeExecutorTy func)
{
    auto out_shape = pad->get_shape();
    vector<T> out_vec(shape_size(out_shape));
    auto pad_value = std::static_pointer_cast<op::Constant>(pad->get_argument(1));

-    runtime::reference::pad<T>(constant->get_vector<T>().data(),
-                               pad_value->get_vector<T>().data(),
-                               out_vec.data(),
-                               constant->get_shape(),
-                               out_shape,
-                               pad->get_padding_below(),
-                               pad->get_padding_above(),
-                               pad->get_pad_mode());
+    if (func != nullptr)
+    {
+        vector<void*> inputs;
+        inputs.push_back(const_cast<void*>(constant->get_data_ptr()));
+        inputs.push_back(const_cast<void*>(pad_value->get_data_ptr()));
+        vector<void*> outputs;
+        outputs.push_back(out_vec.data());
+        func(inputs, outputs);
+    }
+    else
+    {
+        runtime::reference::pad<T>(constant->get_data_ptr<T>(),
+                                   pad_value->get_data_ptr<T>(),
+                                   out_vec.data(),
+                                   constant->get_shape(),
+                                   out_shape,
+                                   pad->get_padding_below(),
+                                   pad->get_padding_above(),
+                                   pad->get_pad_mode());
+    }

    return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
@@ -105,7 +134,7 @@ void pass::ConstantFolding::construct_constant_pad()
    auto pad = make_shared<op::Pad>(
        constant_label, pad_value_label, padding_below, padding_above, pad_mode);

-    auto constant_pad_callback = [constant_label](pattern::Matcher& m) {
+    auto constant_pad_callback = [&, constant_label](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for constant_pad_callback against node = "
                     << m.get_match_root()->get_name();
@@ -114,25 +143,37 @@ void pass::ConstantFolding::construct_constant_pad()
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]); auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto pad_match = static_pointer_cast<op::Pad>(m.get_match_root()); auto pad_match = static_pointer_cast<op::Pad>(m.get_match_root());
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
auto handler = m_cfmap.find(type_index(typeid(ngraph::op::Pad)));
NGRAPH_ASSERT(handler != m_cfmap.end()) << "constant folding map should have pad entry";
func = handler->second(pad_match.get());
}
auto type = constant_match->get_element_type(); auto type = constant_match->get_element_type();
if (type == element::i32) if (type == element::i32)
{ {
replace_node(m.get_match_root(), make_constant_pad<int>(constant_match, pad_match)); replace_node(m.get_match_root(),
fold_constant_pad<int>(constant_match, pad_match, func));
return true; return true;
} }
else if (type == element::i8) else if (type == element::i8)
{ {
replace_node(m.get_match_root(), make_constant_pad<int8_t>(constant_match, pad_match)); replace_node(m.get_match_root(),
fold_constant_pad<int8_t>(constant_match, pad_match, func));
return true; return true;
} }
else if (type == element::f32) else if (type == element::f32)
{ {
replace_node(m.get_match_root(), make_constant_pad<float>(constant_match, pad_match)); replace_node(m.get_match_root(),
fold_constant_pad<float>(constant_match, pad_match, func));
return true; return true;
} }
else if (type == element::f64) else if (type == element::f64)
{ {
replace_node(m.get_match_root(), make_constant_pad<double>(constant_match, pad_match)); replace_node(m.get_match_root(),
fold_constant_pad<double>(constant_match, pad_match, func));
return true; return true;
} }
@@ -150,7 +191,7 @@ void pass::ConstantFolding::construct_constant_reshape()
        element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
    auto reshape = make_shared<op::Reshape>(constant_label, AxisVector{0, 1}, Shape{2, 4, 1});

-    auto constant_reshape_callback = [constant_label](pattern::Matcher& m) {
+    auto constant_reshape_callback = [&, constant_label](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for constant_reshape_callback against node = "
                     << m.get_match_root()->get_name();
@@ -159,29 +200,38 @@ void pass::ConstantFolding::construct_constant_reshape()
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]); auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto reshape_match = static_pointer_cast<op::Reshape>(m.get_match_root()); auto reshape_match = static_pointer_cast<op::Reshape>(m.get_match_root());
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
auto handler = m_cfmap.find(type_index(typeid(ngraph::op::Reshape)));
NGRAPH_ASSERT(handler != m_cfmap.end())
<< "constant folding map should have reshape entry";
func = handler->second(reshape_match.get());
}
auto type = constant_match->get_element_type(); auto type = constant_match->get_element_type();
if (type == element::i32) if (type == element::i32)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_reshape<int>(constant_match, reshape_match)); fold_constant_reshape<int>(constant_match, reshape_match, func));
return true; return true;
} }
else if (type == element::i8) else if (type == element::i8)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_reshape<int8_t>(constant_match, reshape_match)); fold_constant_reshape<int8_t>(constant_match, reshape_match, func));
return true; return true;
} }
else if (type == element::f32) else if (type == element::f32)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_reshape<float>(constant_match, reshape_match)); fold_constant_reshape<float>(constant_match, reshape_match, func));
return true; return true;
} }
else if (type == element::f64) else if (type == element::f64)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_reshape<double>(constant_match, reshape_match)); fold_constant_reshape<double>(constant_match, reshape_match, func));
return true; return true;
} }
@@ -194,17 +244,30 @@ void pass::ConstantFolding::construct_constant_reshape()
} }
template <class T> template <class T>
shared_ptr<op::Constant> make_constant_broadcast(shared_ptr<op::Constant> constant, shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> constant,
shared_ptr<op::Broadcast> broadcast) shared_ptr<op::Broadcast> broadcast,
NodeExecutorTy func)
{ {
auto out_shape = broadcast->get_shape(); auto out_shape = broadcast->get_shape();
vector<T> out_vec(shape_size(out_shape)); vector<T> out_vec(shape_size(out_shape));
runtime::reference::broadcast<T>(constant->get_vector<T>().data(), if (func != nullptr)
{
vector<void*> inputs;
inputs.push_back(const_cast<void*>(constant->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
func(inputs, outputs);
}
else
{
runtime::reference::broadcast<T>(constant->get_data_ptr<T>(),
out_vec.data(), out_vec.data(),
constant->get_shape(), constant->get_shape(),
out_shape, out_shape,
broadcast->get_broadcast_axes()); broadcast->get_broadcast_axes());
}
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
} }
@@ -216,7 +279,7 @@ void pass::ConstantFolding::construct_constant_broadcast()
    auto broadcast = make_shared<op::Broadcast>(constant_label, Shape{2, 4}, AxisSet{1});

-    auto constant_broadcast_callback = [constant_label](pattern::Matcher& m) {
+    auto constant_broadcast_callback = [&, constant_label](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for constant_broadcast_callback against node = "
                     << m.get_match_root()->get_name();
@@ -225,29 +288,38 @@ void pass::ConstantFolding::construct_constant_broadcast()
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]); auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto broadcast_match = static_pointer_cast<op::Broadcast>(m.get_match_root()); auto broadcast_match = static_pointer_cast<op::Broadcast>(m.get_match_root());
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
auto handler = m_cfmap.find(type_index(typeid(ngraph::op::Broadcast)));
NGRAPH_ASSERT(handler != m_cfmap.end())
<< "constant folding map should have broadcast entry";
func = handler->second(broadcast_match.get());
}
auto type = constant_match->get_element_type(); auto type = constant_match->get_element_type();
if (type == element::i32) if (type == element::i32)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_broadcast<int>(constant_match, broadcast_match)); fold_constant_broadcast<int>(constant_match, broadcast_match, func));
return true; return true;
} }
else if (type == element::i8) else if (type == element::i8)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_broadcast<int8_t>(constant_match, broadcast_match)); fold_constant_broadcast<int8_t>(constant_match, broadcast_match, func));
return true; return true;
} }
else if (type == element::f32) else if (type == element::f32)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_broadcast<float>(constant_match, broadcast_match)); fold_constant_broadcast<float>(constant_match, broadcast_match, func));
return true; return true;
} }
else if (type == element::f64) else if (type == element::f64)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_broadcast<double>(constant_match, broadcast_match)); fold_constant_broadcast<double>(constant_match, broadcast_match, func));
return true; return true;
} }
@@ -260,59 +332,61 @@ void pass::ConstantFolding::construct_constant_broadcast()
}

template <class T>
-shared_ptr<op::Constant> make_constant_binary(shared_ptr<op::Constant> a,
+shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
                                              shared_ptr<op::Constant> b,
-                                              shared_ptr<Node> binary)
+                                              shared_ptr<Node> binary,
+                                              NodeExecutorTy func)
{
    auto out_shape = binary->get_shape();
    vector<T> out_vec(shape_size(out_shape));

+    if (func != nullptr)
+    {
+        vector<void*> inputs;
+        inputs.push_back(const_cast<void*>(a->get_data_ptr()));
+        inputs.push_back(const_cast<void*>(b->get_data_ptr()));
+        vector<void*> outputs;
+        outputs.push_back(out_vec.data());
+        func(inputs, outputs);
+    }
+    else
+    {
    if (std::dynamic_pointer_cast<op::Add>(binary))
    {
-        runtime::reference::add<T>(a->get_vector<T>().data(),
-                                   b->get_vector<T>().data(),
-                                   out_vec.data(),
-                                   shape_size(out_shape));
+        runtime::reference::add<T>(
+            a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
    }
    else if (std::dynamic_pointer_cast<op::Subtract>(binary))
    {
-        runtime::reference::subtract<T>(a->get_vector<T>().data(),
-                                        b->get_vector<T>().data(),
-                                        out_vec.data(),
-                                        shape_size(out_shape));
+        runtime::reference::subtract<T>(
+            a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
    }
    else if (std::dynamic_pointer_cast<op::Multiply>(binary))
    {
-        runtime::reference::multiply<T>(a->get_vector<T>().data(),
-                                        b->get_vector<T>().data(),
-                                        out_vec.data(),
-                                        shape_size(out_shape));
+        runtime::reference::multiply<T>(
+            a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
    }
    else if (std::dynamic_pointer_cast<op::Divide>(binary))
    {
-        runtime::reference::divide<T>(a->get_vector<T>().data(),
-                                      b->get_vector<T>().data(),
-                                      out_vec.data(),
-                                      shape_size(out_shape));
+        runtime::reference::divide<T>(
+            a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
    }
    else if (std::dynamic_pointer_cast<op::Minimum>(binary))
    {
-        runtime::reference::minimum<T>(a->get_vector<T>().data(),
-                                       b->get_vector<T>().data(),
-                                       out_vec.data(),
-                                       shape_size(out_shape));
+        runtime::reference::minimum<T>(
+            a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
    }
    else if (std::dynamic_pointer_cast<op::Maximum>(binary))
    {
-        runtime::reference::maximum<T>(a->get_vector<T>().data(),
-                                       b->get_vector<T>().data(),
-                                       out_vec.data(),
-                                       shape_size(out_shape));
+        runtime::reference::maximum<T>(
+            a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
    }
    else
    {
        NGRAPH_ASSERT(false)
-            << "make_constant_binary must be consistent with is_supported_binary_op";
+            << "fold_constant_binary must be consistent with is_supported_binary_op";
+    }
    }

    return make_shared<op::Constant>(a->get_element_type(), out_shape, out_vec);
@@ -335,7 +409,7 @@ void pass::ConstantFolding::construct_constant_binary()
    auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>();
    auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});

-    auto constant_binary_callback = [a, b](pattern::Matcher& m) {
+    auto constant_binary_callback = [&, a, b](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for constant_binary_callback against node = "
                     << m.get_match_root()->get_name();
@@ -350,29 +424,39 @@ void pass::ConstantFolding::construct_constant_binary()
return false; return false;
} }
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
auto& node = *binary_match;
auto handler = m_cfmap.find(type_index(typeid(node)));
NGRAPH_ASSERT(handler != m_cfmap.end())
<< "constant folding map should have an entry for " << binary_match->get_name();
func = handler->second(binary_match.get());
}
auto type = a_match->get_element_type(); auto type = a_match->get_element_type();
if (type == element::i32) if (type == element::i32)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_binary<int>(a_match, b_match, binary_match)); fold_constant_binary<int>(a_match, b_match, binary_match, func));
return true; return true;
} }
else if (type == element::i8) else if (type == element::i8)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_binary<int8_t>(a_match, b_match, binary_match)); fold_constant_binary<int8_t>(a_match, b_match, binary_match, func));
return true; return true;
} }
else if (type == element::f32) else if (type == element::f32)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_binary<float>(a_match, b_match, binary_match)); fold_constant_binary<float>(a_match, b_match, binary_match, func));
return true; return true;
} }
else if (type == element::f64) else if (type == element::f64)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_binary<double>(a_match, b_match, binary_match)); fold_constant_binary<double>(a_match, b_match, binary_match, func));
return true; return true;
} }
@@ -391,41 +475,59 @@ bool is_supported_unary_op(std::shared_ptr<Node> n)
}

template <class T>
-shared_ptr<op::Constant> make_constant_unary(shared_ptr<op::Constant> constant,
-                                             shared_ptr<Node> unary)
+shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
+                                             shared_ptr<Node> unary,
+                                             NodeExecutorTy func)
{
+    //check sqrt arg
+    if (std::dynamic_pointer_cast<op::Sqrt>(unary))
+    {
+        std::vector<T> values{constant->get_vector<T>()};
+        if (std::any_of(values.begin(), values.end(), [](T i) { return i < 0; }))
+        {
+            throw ngraph_error("Square root of negative value");
+        }
+    }
+
    auto out_shape = unary->get_shape();
    vector<T> out_vec(shape_size(out_shape));

+    if (func != nullptr)
+    {
+        vector<void*> inputs;
+        inputs.push_back(const_cast<void*>(constant->get_data_ptr()));
+        vector<void*> outputs;
+        outputs.push_back(out_vec.data());
+        func(inputs, outputs);
+    }
+    else
+    {
    if (std::dynamic_pointer_cast<op::Abs>(unary))
    {
        runtime::reference::abs<T>(
-            constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape));
+            constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
    }
    else if (std::dynamic_pointer_cast<op::Negative>(unary))
    {
        runtime::reference::negate<T>(
-            constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape));
+            constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
    }
    else if (std::dynamic_pointer_cast<op::Relu>(unary))
    {
        runtime::reference::relu<T>(
-            constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape));
+            constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
    }
    else if (std::dynamic_pointer_cast<op::Sqrt>(unary))
    {
-        std::vector<T> values{constant->get_vector<T>()};
-        if (std::any_of(values.begin(), values.end(), [](T i) { return i < 0; }))
-        {
-            throw ngraph_error("Square root of negative value");
-        }
        runtime::reference::sqrt<T>(
-            constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape));
+            constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
    }
    else
    {
        NGRAPH_ASSERT(false) << "must be consistent with is_supported_unary_op";
    }
+    }

    return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
@@ -438,8 +540,8 @@ void pass::ConstantFolding::construct_constant_unary()
    auto uea =
        std::make_shared<pattern::op::Any>(constant_label, is_uea, NodeVector{constant_label});

-    auto constant_unary_callback = [constant_label](pattern::Matcher& m) {
-        NGRAPH_DEBUG << "In callback for constant_reshape_callback against node = "
+    auto constant_unary_callback = [&, constant_label](pattern::Matcher& m) {
+        NGRAPH_DEBUG << "In callback for constant_unary_callback against node = "
                     << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();
@@ -452,28 +554,39 @@ void pass::ConstantFolding::construct_constant_unary()
return false; return false;
} }
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
auto& node = *unary_match;
auto handler = m_cfmap.find(type_index(typeid(node)));
NGRAPH_ASSERT(handler != m_cfmap.end())
<< "constant folding map should have an entry for " << unary_match->get_name();
func = handler->second(unary_match.get());
}
auto type = constant_match->get_element_type(); auto type = constant_match->get_element_type();
if (type == element::i32) if (type == element::i32)
{ {
replace_node(m.get_match_root(), make_constant_unary<int>(constant_match, unary_match)); replace_node(m.get_match_root(),
fold_constant_unary<int>(constant_match, unary_match, func));
return true; return true;
} }
else if (type == element::i8) else if (type == element::i8)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_unary<int8_t>(constant_match, unary_match)); fold_constant_unary<int8_t>(constant_match, unary_match, func));
return true; return true;
} }
else if (type == element::f32) else if (type == element::f32)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_unary<float>(constant_match, unary_match)); fold_constant_unary<float>(constant_match, unary_match, func));
return true; return true;
} }
else if (type == element::f64) else if (type == element::f64)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_unary<double>(constant_match, unary_match)); fold_constant_unary<double>(constant_match, unary_match, func));
return true; return true;
} }
@@ -486,7 +599,7 @@ void pass::ConstantFolding::construct_constant_unary()
} }
template <class QUANT, class REAL> template <class QUANT, class REAL>
shared_ptr<op::Constant> make_constant_dequantize(shared_ptr<op::Constant> constant, shared_ptr<op::Constant> fold_constant_dequantize(shared_ptr<op::Constant> constant,
shared_ptr<op::Dequantize> dequant, shared_ptr<op::Dequantize> dequant,
shared_ptr<op::Constant> scale, shared_ptr<op::Constant> scale,
shared_ptr<op::Constant> offset) shared_ptr<op::Constant> offset)
@@ -538,14 +651,14 @@ void pass::ConstantFolding::construct_constant_dequantize()
if (type == element::u8) if (type == element::u8)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_dequantize<uint8_t, float>( fold_constant_dequantize<uint8_t, float>(
constant_match, dequantize_op, scale, offset)); constant_match, dequantize_op, scale, offset));
return true; return true;
} }
else if (type == element::i8) else if (type == element::i8)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_dequantize<int8_t, float>( fold_constant_dequantize<int8_t, float>(
constant_match, dequantize_op, scale, offset)); constant_match, dequantize_op, scale, offset));
return true; return true;
} }
@@ -559,7 +672,7 @@ void pass::ConstantFolding::construct_constant_dequantize()
} }
template <class REAL, class QUANT> template <class REAL, class QUANT>
shared_ptr<op::Constant> make_constant_quantize(shared_ptr<op::Constant> constant, shared_ptr<op::Constant> fold_constant_quantize(shared_ptr<op::Constant> constant,
shared_ptr<op::Quantize> quant, shared_ptr<op::Quantize> quant,
shared_ptr<op::Constant> scale, shared_ptr<op::Constant> scale,
shared_ptr<op::Constant> offset) shared_ptr<op::Constant> offset)
@@ -614,14 +727,14 @@ void pass::ConstantFolding::construct_constant_quantize()
{ {
replace_node( replace_node(
m.get_match_root(), m.get_match_root(),
make_constant_quantize<float, uint8_t>(constant_match, quantize_op, scale, offset)); fold_constant_quantize<float, uint8_t>(constant_match, quantize_op, scale, offset));
return true; return true;
} }
else if (type == element::i8) else if (type == element::i8)
{ {
replace_node( replace_node(
m.get_match_root(), m.get_match_root(),
make_constant_quantize<float, int8_t>(constant_match, quantize_op, scale, offset)); fold_constant_quantize<float, int8_t>(constant_match, quantize_op, scale, offset));
return true; return true;
} }
......
@@ -17,6 +17,7 @@
#pragma once

#include "ngraph/pass/graph_rewrite.hpp"
+#include "ngraph/util.hpp"

namespace ngraph
{
@@ -40,9 +41,10 @@ public:
        QUANTIZE
    };

-    ConstantFolding()
+    ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
        : GraphRewrite()
    {
+        m_cfmap = cfmap;
        construct_constant_reshape();
        construct_constant_broadcast();
        construct_constant_pad();
@@ -54,9 +56,11 @@ public:
    //this allows to specify the order in which matchers will be run
    //and also allows to register the same matcher more than once
-    ConstantFolding(const std::vector<CFTransformations>& transformations)
+    ConstantFolding(const std::vector<CFTransformations>& transformations,
+                    const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
        : GraphRewrite()
    {
+        m_cfmap = cfmap;
        for (auto cft : transformations)
        {
            switch (cft)
@@ -80,4 +84,6 @@ private:
    void construct_constant_binary();
    void construct_constant_quantize();
    void construct_constant_dequantize();
+
+    ngraph::BuildNodeExecutorMap m_cfmap;
};
@@ -29,22 +29,18 @@ namespace ngraph
{ {
namespace cpu namespace cpu
{ {
template <> static void get_broadcast_kernel(
void Builder::BUILDER_DECL(ngraph::op::Broadcast) const ngraph::Node* node,
std::function<decltype(runtime::cpu::kernel::broadcast<float, 2>)>& kernel,
Shape& expanded_input_shape,
Shape& out_shape,
size_t& size)
{ {
auto& functors = external_function->get_functors();
auto broadcast = static_cast<const ngraph::op::Broadcast*>(node); auto broadcast = static_cast<const ngraph::op::Broadcast*>(node);
auto broadcast_axes = broadcast->get_broadcast_axes(); auto broadcast_axes = broadcast->get_broadcast_axes();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); auto arg_shape = broadcast->get_argument(0)->get_shape();
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); out_shape = broadcast->get_shape();
auto arg_shape = args[0].get_shape();
auto out_shape = out[0].get_shape();
// TODO(jmenon): Shape transformations, rank reduction etc. needs to be general
// and not in any one builder. Move this to the Halide analysis phase.
// Transform output shape - ex. [4, 1, 2, 2] -> [4, 1, 4] // Transform output shape - ex. [4, 1, 2, 2] -> [4, 1, 4]
// if we're not broadcasting along axes 2 and 3 // if we're not broadcasting along axes 2 and 3
@@ -96,9 +92,7 @@ namespace ngraph
else else
{ {
broadcast_axes.erase(i); broadcast_axes.erase(i);
// TODO(jmenon): This needs to be rewritten
// when it gets moved to the analysis pass
// that doesn't use AxisSet
auto new_bcast_axes = AxisSet{}; auto new_bcast_axes = AxisSet{};
for (auto axis : broadcast_axes) for (auto axis : broadcast_axes)
{ {
@@ -128,11 +122,7 @@ namespace ngraph
if (broadcast_axes.empty()) if (broadcast_axes.empty())
{ {
size_t size = out[0].get_size() * out[0].get_element_type().size(); size = shape_size(out_shape) * broadcast->get_element_type().size();
auto functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
memcpy(out_tensor, arg_tensor, size);
};
functors.emplace_back(functor);
return; return;
} }
@@ -146,7 +136,7 @@ namespace ngraph
// so expand as needed // so expand as needed
// Ex. [2] -> [2, 1] for output shape [2, 4] // Ex. [2] -> [2, 1] for output shape [2, 4]
auto expanded_input_shape = Shape(out_rank, 1); expanded_input_shape = Shape(out_rank, 1);
size_t i = 0; size_t i = 0;
for (size_t j = 0; j < out_rank; j++) for (size_t j = 0; j < out_rank; j++)
{ {
@@ -160,17 +150,70 @@ namespace ngraph
} }
} }
SELECT_KERNEL_BY_RANK(kernel,
broadcast->get_input_element_type(0),
out_rank,
runtime::cpu::kernel::broadcast);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Broadcast)
{
std::function<decltype(runtime::cpu::kernel::broadcast<float, 2>)> kernel; std::function<decltype(runtime::cpu::kernel::broadcast<float, 2>)> kernel;
Shape expanded_input_shape, out_shape;
size_t size;
get_broadcast_kernel(node, kernel, expanded_input_shape, out_shape, size);
NodeExecutorTy functor;
if (kernel)
{
functor = [kernel, expanded_input_shape, out_shape](
const std::vector<void*> inputs, std::vector<void*> outputs) {
kernel(inputs[0], outputs[0], expanded_input_shape, out_shape, 0);
};
}
else
{
functor = [size](const std::vector<void*>& inputs,
std::vector<void*>& outputs) {
memcpy(outputs[0], inputs[0], size);
};
}
return functor;
}
REGISTER_CF_BUILDER(Broadcast);
SELECT_KERNEL_BY_RANK( template <>
kernel, args[0].get_element_type(), out_rank, runtime::cpu::kernel::broadcast); void Builder::BUILDER_DECL(ngraph::op::Broadcast)
{
auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto functor = [&, kernel, expanded_input_shape, out_shape]( std::function<decltype(runtime::cpu::kernel::broadcast<float, 2>)> kernel;
Shape expanded_input_shape, out_shape;
size_t size;
get_broadcast_kernel(node, kernel, expanded_input_shape, out_shape, size);
CPUKernelFunctor functor;
if (kernel)
{
functor = [&, kernel, expanded_input_shape, out_shape](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, expanded_input_shape, out_shape, ectx->arena); kernel(
arg_tensor, out_tensor, expanded_input_shape, out_shape, ectx->arena);
};
functors.emplace_back(functor);
}
else
{
functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
memcpy(out_tensor, arg_tensor, size);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
} }
}
REGISTER_OP_BUILDER(Broadcast); REGISTER_OP_BUILDER(Broadcast);
} }
......
@@ -50,7 +50,7 @@ namespace ngraph
                auto padding_above = pad->get_padding_above();
                auto pad_mode = pad->get_pad_mode();

-                if (pad->get_pad_mode() == ngraph::op::PadMode::CONSTANT)
+                if (pad_mode == ngraph::op::PadMode::CONSTANT)
                {
                    std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
@@ -97,6 +97,64 @@ }
                }
                REGISTER_OP_BUILDER(Pad);
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Pad)
{
auto pad = static_cast<const ngraph::op::Pad*>(node);
auto arg_shape = pad->get_argument(0)->get_shape();
auto out_shape = pad->get_shape();
auto padding_below = pad->get_padding_below();
auto padding_above = pad->get_padding_above();
auto pad_mode = pad->get_pad_mode();
if (pad_mode == ngraph::op::PadMode::CONSTANT)
{
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
pad->get_input_element_type(0),
arg_shape.size(),
runtime::cpu::kernel::pad_and_slice);
auto functor = [kernel, arg_shape, out_shape, padding_below, padding_above](
const std::vector<void*>& inputs, std::vector<void*>& outputs) {
kernel(inputs[0],
outputs[0],
inputs[1],
arg_shape,
out_shape,
CoordinateDiff(padding_below.begin(), padding_below.end()),
CoordinateDiff(padding_above.begin(), padding_above.end()),
0);
};
return functor;
}
else
{
std::function<decltype(runtime::cpu::kernel::pad_ref<float>)> kernel;
SELECT_KERNEL(
kernel, pad->get_input_element_type(0), runtime::cpu::kernel::pad_ref);
auto functor =
[kernel, arg_shape, out_shape, padding_below, padding_above, pad_mode](
const std::vector<void*>& inputs, std::vector<void*>& outputs) {
kernel(inputs[0],
inputs[1],
outputs[0],
arg_shape,
out_shape,
padding_below,
padding_above,
pad_mode,
0);
};
return functor;
}
}
REGISTER_CF_BUILDER(Pad);
} }
} }
} }
@@ -31,15 +31,31 @@ namespace ngraph
{ {
namespace cpu namespace cpu
{ {
template <> static void get_reshape_kernel(
void Builder::BUILDER_DECL(ngraph::op::Reshape) const ngraph::Node* node,
std::function<decltype(runtime::cpu::kernel::reshape_1d<float, 2>)>& kernel,
std::function<decltype(runtime::cpu::kernel::reshape_ref<float>)>& ref_kernel,
Shape& arg_shape,
Shape& result_shape,
AxisVector& input_order,
size_t& size,
bool& skip_reshape)
{ {
auto& functors = external_function->get_functors(); auto reshape = static_cast<const ngraph::op::Reshape*>(node);
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); arg_shape = reshape->get_argument(0)->get_shape();
auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); auto arg_rank = arg_shape.size();
auto reshape = static_cast<const ngraph::op::Reshape*>(node); result_shape = reshape->get_output_shape();
auto result_rank = result_shape.size();
auto& result_element_type = reshape->get_element_type();
input_order = reshape->get_input_order();
bool same_layout = is_sorted(input_order.begin(), input_order.end());
auto result_size = shape_size(result_shape);
size = result_size * result_element_type.size();
auto can_skip_reshape = [&]() { auto can_skip_reshape = [&]() {
if (!reshape->get_is_transpose()) if (!reshape->get_is_transpose())
@@ -56,41 +72,15 @@ namespace ngraph
if (can_skip_reshape()) if (can_skip_reshape())
{ {
size_t size = out[0].get_size() * out[0].get_element_type().size(); skip_reshape = true;
auto functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (out_tensor != arg_tensor)
{
memcpy(out_tensor, arg_tensor, size);
}
};
functors.emplace_back(functor);
return; return;
} }
auto arg_shape = args[0].get_shape();
auto arg_rank = arg_shape.size();
auto result_shape = out[0].get_shape();
auto result_rank = result_shape.size();
auto& result_element_type = out[0].get_element_type();
auto input_order = reshape->get_input_order();
bool same_layout = is_sorted(input_order.begin(), input_order.end());
auto result_size = shape_size(result_shape);
if (same_layout || result_size < 2) if (same_layout || result_size < 2)
{ {
size_t size = out[0].get_size() * out[0].get_element_type().size();
auto functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
memcpy(out_tensor, arg_tensor, size);
};
functors.emplace_back(functor);
return; return;
} }
std::function<decltype(runtime::cpu::kernel::reshape_1d<float, 2>)> kernel;
if (arg_rank == 1) if (arg_rank == 1)
{ {
SELECT_KERNEL_BY_RANK( SELECT_KERNEL_BY_RANK(
@@ -113,29 +103,128 @@ namespace ngraph
} }
else else
{ {
std::function<decltype(runtime::cpu::kernel::reshape_ref<float>)> ref_kernel;
SELECT_KERNEL( SELECT_KERNEL(
ref_kernel, result_element_type, runtime::cpu::kernel::reshape_ref); ref_kernel, result_element_type, runtime::cpu::kernel::reshape_ref);
}
}
auto functor = [&, ref_kernel, arg_shape, input_order, result_shape]( template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Reshape)
{
std::function<decltype(runtime::cpu::kernel::reshape_1d<float, 2>)> kernel;
std::function<decltype(runtime::cpu::kernel::reshape_ref<float>)> ref_kernel;
Shape arg_shape, result_shape;
AxisVector input_order;
size_t size;
bool skip_reshape = false;
get_reshape_kernel(node,
kernel,
ref_kernel,
arg_shape,
result_shape,
input_order,
size,
skip_reshape);
NodeExecutorTy functor;
if (kernel)
{
functor = [kernel, arg_shape, input_order, result_shape](
const std::vector<void*>& inputs, std::vector<void*>& outputs) {
kernel(inputs[0], outputs[0], arg_shape, input_order, result_shape, 0);
};
}
else if (ref_kernel)
{
functor = [ref_kernel, arg_shape, input_order, result_shape](
std::vector<void*> inputs, std::vector<void*> outputs) {
ref_kernel(inputs[0], outputs[0], arg_shape, input_order, result_shape, 0);
};
}
else if (skip_reshape)
{
functor = [size](const std::vector<void*>& inputs,
std::vector<void*>& outputs) {
if (inputs[0] != outputs[0])
{
memcpy(outputs[0], inputs[0], size);
}
};
}
else
{
functor = [size](const std::vector<void*>& inputs,
std::vector<void*>& outputs) {
memcpy(outputs[0], inputs[0], size);
};
}
return functor;
}
REGISTER_CF_BUILDER(Reshape);
template <>
void Builder::BUILDER_DECL(ngraph::op::Reshape)
{
auto& functors = external_function->get_functors();
auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
std::function<decltype(runtime::cpu::kernel::reshape_1d<float, 2>)> kernel;
std::function<decltype(runtime::cpu::kernel::reshape_ref<float>)> ref_kernel;
Shape arg_shape, result_shape;
AxisVector input_order;
size_t size;
bool skip_reshape = false;
get_reshape_kernel(node,
kernel,
ref_kernel,
arg_shape,
result_shape,
input_order,
size,
skip_reshape);
CPUKernelFunctor functor;
if (kernel)
{
functor = [&, kernel, arg_shape, input_order, result_shape](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
ref_kernel(arg_tensor, kernel(arg_tensor,
out_tensor, out_tensor,
arg_shape, arg_shape,
input_order, input_order,
result_shape, result_shape,
ectx->arena); ectx->arena);
}; };
functors.emplace_back(functor);
return;
} }
else if (ref_kernel)
auto functor = [&, kernel, arg_shape, input_order, result_shape]( {
functor = [&, ref_kernel, arg_shape, input_order, result_shape](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) { CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel( ref_kernel(arg_tensor,
arg_tensor, out_tensor, arg_shape, input_order, result_shape, ectx->arena); out_tensor,
arg_shape,
input_order,
result_shape,
ectx->arena);
};
}
else if (skip_reshape)
{
functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
if (out_tensor != arg_tensor)
{
memcpy(out_tensor, arg_tensor, size);
}
}; };
}
else
{
functor = [&, size](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
memcpy(out_tensor, arg_tensor, size);
};
}
functors.emplace_back(functor); functors.emplace_back(functor);
} }
......
@@ -26,6 +26,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
+#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
@@ -53,6 +54,7 @@
#include "ngraph/op/or.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/power.hpp"
+#include "ngraph/op/relu.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
@@ -65,6 +67,7 @@
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/kernel/abs.hpp"
#include "ngraph/runtime/cpu/kernel/acos.hpp"
+#include "ngraph/runtime/cpu/kernel/add.hpp"
#include "ngraph/runtime/cpu/kernel/and.hpp"
#include "ngraph/runtime/cpu/kernel/asin.hpp"
#include "ngraph/runtime/cpu/kernel/atan.hpp"
@@ -89,6 +92,7 @@
#include "ngraph/runtime/cpu/kernel/not.hpp"
#include "ngraph/runtime/cpu/kernel/not_equal.hpp"
#include "ngraph/runtime/cpu/kernel/or.hpp"
+#include "ngraph/runtime/cpu/kernel/relu.hpp"
#include "ngraph/runtime/cpu/kernel/result.hpp"
#include "ngraph/runtime/cpu/kernel/sign.hpp"
#include "ngraph/runtime/cpu/kernel/sin.hpp"
@@ -365,6 +369,66 @@ namespace ngraph
                functors.emplace_back(functor);
            }
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Add)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::add);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Subtract)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::subtract);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Multiply)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::multiply);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Divide)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::divide);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Minimum)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::minimum);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Maximum)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::maximum);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Abs)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::abs);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Negative)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::negative);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Relu)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::relu);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Sqrt)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sqrt);
}
#define TI(x) type_index(typeid(x)) #define TI(x) type_index(typeid(x))
BuildOpMap& GetGlobalBuildDispatcher() BuildOpMap& GetGlobalBuildDispatcher()
@@ -379,6 +443,12 @@ namespace ngraph
                return build_dispatcher;
            }
BuildNodeExecutorMap& GetGlobalCFDispatcherCPU()
{
static BuildNodeExecutorMap build_cf_dispatcher_cpu{};
return build_cf_dispatcher_cpu;
}
REGISTER_OP_BUILDER(Constant); REGISTER_OP_BUILDER(Constant);
REGISTER_OP_BUILDER(Result); REGISTER_OP_BUILDER(Result);
REGISTER_OP_BUILDER(Subtract); REGISTER_OP_BUILDER(Subtract);
@@ -414,6 +484,17 @@ namespace ngraph
            REGISTER_OP_BUILDER(Minimum);
            REGISTER_OP_BUILDER(And);
            REGISTER_OP_BUILDER(Or);
REGISTER_CF_BUILDER(Add);
REGISTER_CF_BUILDER(Subtract);
REGISTER_CF_BUILDER(Multiply);
REGISTER_CF_BUILDER(Divide);
REGISTER_CF_BUILDER(Minimum);
REGISTER_CF_BUILDER(Maximum);
REGISTER_CF_BUILDER(Abs);
REGISTER_CF_BUILDER(Negative);
REGISTER_CF_BUILDER(Relu);
REGISTER_CF_BUILDER(Sqrt);
} }
} }
} }
@@ -232,6 +232,32 @@
    };                                                                                             \
    functors.emplace_back(functor);
#define BUILD_UNARY_ELEMWISE_CF_FUNCTOR(OP) \
std::function<void(void*, void*, size_t, int)> kernel; \
\
SELECT_KERNEL(kernel, node->get_input_element_type(0), OP); \
\
auto element_count = shape_size(node->get_shape()); \
\
auto functor = [&, kernel, element_count](const std::vector<void*>& inputs, \
std::vector<void*>& outputs) { \
kernel(inputs[0], outputs[0], element_count, 0); \
}; \
return functor;
#define BUILD_BINARY_ELEMWISE_CF_FUNCTOR(OP) \
std::function<void(void*, void*, void*, size_t, int)> kernel; \
\
SELECT_KERNEL(kernel, node->get_input_element_type(0), OP); \
\
auto element_count = shape_size(node->get_shape()); \
\
auto functor = [&, kernel, element_count](const std::vector<void*>& inputs, \
std::vector<void*>& outputs) { \
kernel(inputs[0], inputs[1], outputs[0], element_count, 0); \
}; \
return functor;
#define REGISTER_OP_BUILDER(OP) \ #define REGISTER_OP_BUILDER(OP) \
static struct __register_##OP##_builder \ static struct __register_##OP##_builder \
{ \ { \
@@ -253,6 +279,29 @@
    }                                                                                              \
    } __register_##OP##_builder_instance;
#define BUILDER_CF_DECL(op_name) CFbuild<op_name>(const ngraph::Node* node)
#define REGISTER_CF_BUILDER(OP) \
static struct __register_##OP##_cf_builder \
{ \
__register_##OP##_cf_builder() \
{ \
GetGlobalCFDispatcherCPU().insert({type_index(typeid(ngraph::op::OP)), \
&runtime::cpu::Builder::CFbuild<ngraph::op::OP>}); \
} \
} __register_##OP##_cf_builder_instance;
#define REGISTER_CPU_CF_BUILDER(OP) \
static struct __register_##OP##_cf_builder \
{ \
__register_##OP##_cf_builder() \
{ \
GetGlobalCFDispatcherCPU().insert( \
{type_index(typeid(ngraph::runtime::cpu::op::OP)), \
&runtime::cpu::Builder::CFbuild<ngraph::runtime::cpu::op::OP>}); \
} \
} __register_##OP##_cf_builder_instance;
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
@@ -269,6 +318,9 @@ namespace ngraph
BuildOpMap& GetGlobalBuildDispatcher(); BuildOpMap& GetGlobalBuildDispatcher();
// build the map to use cpu kernel for node execution
BuildNodeExecutorMap& GetGlobalCFDispatcherCPU();
class Builder class Builder
{ {
public: public:
@@ -282,6 +334,13 @@
"' in CPU builder"); "' in CPU builder");
} }
template <typename OP>
static NodeExecutorTy CFbuild(const ngraph::Node* node)
{
throw unsupported_op("Unimplemented op '" + node->description() +
"' for constant folding in CPU builder");
}
static void nop(CPU_ExternalFunction* external_function, static void nop(CPU_ExternalFunction* external_function,
const ngraph::Node* node, const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args, const std::vector<TensorViewWrapper>& args,
......
@@ -1140,7 +1140,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
    NodeVector nv_cwi; // We dont need CPUWorkspaceInsertion to return list of indices
    REGISTER_KNOBBED_PASS_WITH_ARGS(CPUWorkspaceInsertion, true, runtime::cpu::pass, nv_cwi, false);
    REGISTER_KNOBBED_PASS_WITH_ARGS(CPUAssignment, true, runtime::cpu::pass, this);
-    REGISTER_KNOBBED_PASS(ConstantFolding, false, ngraph::pass);
+    REGISTER_KNOBBED_PASS_WITH_ARGS(
+        ConstantFolding, true, ngraph::pass, GetGlobalCFDispatcherCPU());
    REGISTER_KNOBBED_PASS_WITH_ARGS(CPULayout, true, runtime::cpu::pass, this);
    REGISTER_KNOBBED_PASS_WITH_ARGS(
        CommonSubexpressionElimination, true, ngraph::pass, runtime::cpu::get_cse_handlers_map());
......
@@ -25,6 +25,9 @@
#include <memory>
#include <sstream>
#include <string>
+#include <typeindex>
+#include <typeinfo>
+#include <unordered_map>
#include <vector>

#include "ngraph/axis_vector.hpp"
@@ -214,12 +217,21 @@ namespace ngraph
 * This utility takes forward-propogation and back-propagation functions
 * and turns them into clone functions where the intermediate values of
 * the forward prop are added to the output of fprop and the input of the bprop
- * to avoid repeat calcualtions.
+ * to avoid repeat calculations.
 * The last argument is the adjoints coming into the bprop function, the output
 * bprop function will have these nodes as the first N input parameters
 **/
FpropCache cache_fprop(std::shared_ptr<Function> fprop, std::shared_ptr<Function> bprop);

+// NodeExecutors are used in compiler optimization passes like ConstantFolding to execute a node
+// using the supplied input and output memory locations.
+// A BuildNodeExecutor returns a backend-specific NodeExecutor for a given Node type
+using NodeExecutorTy =
+    std::function<void(const std::vector<void*>& inputs, std::vector<void*>& outputs)>;
+using BuildNodeExecutor = std::function<NodeExecutorTy(const ngraph::Node*)>;
+using BuildNodeExecutorMap = std::unordered_map<std::type_index, BuildNodeExecutor>;
+
enum class CPUTensorRole
{
    INPUT,
......
@@ -31,9 +31,11 @@
#include "ngraph/op/erf.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
+#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp"
+#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
@@ -949,6 +951,184 @@ TEST(cpu_test, rotated_pooling)
        make_f(false, false), make_f(false, false), "INTERPRETER", "CPU"); // 5D MaxPool
}
TEST(cpu_test, constant_reshape)
{
Shape shape_in{2, 4};
Shape shape_out{2, 4, 1};
const vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
auto reshape = make_shared<op::Reshape>(constant, AxisVector{0, 1}, shape_out);
auto f = make_shared<Function>(reshape, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>(
ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
const vector<float> values_out = new_const->get_vector<float>();
EXPECT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(cpu_test, constant_reshape_permute)
{
Shape shape_in{2, 4};
Shape shape_out{4, 2};
vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
auto constant = make_shared<op::Constant>(element::f64, shape_in, values_in);
auto reshape = make_shared<op::Reshape>(constant, AxisVector{1, 0}, shape_out);
auto f = make_shared<Function>(reshape, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>(
ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
const vector<double> values_out = new_const->get_vector<double>();
const vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
EXPECT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(cpu_test, constant_broadcast)
{
Shape shape_in{2};
Shape shape_out{2, 4};
vector<int> values_in{0, 1};
auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
auto broadcast = make_shared<op::Broadcast>(constant, shape_out, AxisSet{1});
auto f = make_shared<Function>(broadcast, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>(
ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> values_permute{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_permute, values_out);
}
TEST(cpu_test, constant_pad_exterior)
{
Shape shape_in{2};
vector<int> values_in{777, 888};
auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
auto pad_value = make_shared<op::Constant>(element::i32, Shape{}, vector<int>{111});
CoordinateDiff padding_below{1};
CoordinateDiff padding_above{2};
auto broadcast = make_shared<op::Pad>(constant, pad_value, padding_below, padding_above);
auto f = make_shared<Function>(broadcast, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>(
ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Pad>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> padded_values{111, 777, 888, 111, 111};
ASSERT_EQ(padded_values, values_out);
}
template <typename T>
static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t pos)
{
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(pos)->get_argument(0));
return new_const->get_vector<T>();
}
TEST(cpu_test, constant_unary_binary)
{
Shape shape_in{4};
vector<int> values_a{1, 2, 3, 4};
vector<int> values_b{1, 2, 3, 4};
vector<int> values_c{-1, -1, -1, -1};
vector<int> values_d{1, 4, 9, 16};
vector<int> values_e{1, -2, -3, 4};
auto a = make_shared<op::Constant>(element::i32, shape_in, values_a);
auto b = make_shared<op::Constant>(element::i32, shape_in, values_b);
auto c = make_shared<op::Constant>(element::i32, shape_in, values_c);
auto d = make_shared<op::Constant>(element::i32, shape_in, values_d);
auto e = make_shared<op::Constant>(element::i32, shape_in, values_e);
auto add = a + b;
auto sub = a - b;
auto mul = a * b;
auto divn = a / b;
auto min = make_shared<op::Minimum>(c, a);
auto max = make_shared<op::Maximum>(a, c);
auto absn = make_shared<op::Abs>(c);
auto neg = make_shared<op::Negative>(c);
auto sqrt = make_shared<op::Sqrt>(d);
auto neg_sqrt = make_shared<op::Sqrt>(c);
auto relu = make_shared<op::Relu>(e);
auto f = make_shared<Function>(NodeVector{add, sub, mul, divn, min, max, absn, neg, sqrt, relu},
ParameterVector{});
auto f_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>(
ngraph::runtime::cpu::GetGlobalCFDispatcherCPU());
pass_manager.run_passes(f);
//expected values
vector<int> add_expected{2, 4, 6, 8};
vector<int> sub_expected{0, 0, 0, 0};
vector<int> mul_expected{1, 4, 9, 16};
vector<int> div_expected{1, 1, 1, 1};
vector<int> min_expected{-1, -1, -1, -1};
vector<int> max_expected{1, 2, 3, 4};
vector<int> abs_neg_expected{1, 1, 1, 1};
vector<int> sqrt_expected{1, 2, 3, 4};
vector<int> relu_expected{1, 0, 0, 4};
ASSERT_EQ(get_result_constant<int>(f, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(f, 1), sub_expected);
ASSERT_EQ(get_result_constant<int>(f, 2), mul_expected);
ASSERT_EQ(get_result_constant<int>(f, 3), div_expected);
ASSERT_EQ(get_result_constant<int>(f, 4), min_expected);
ASSERT_EQ(get_result_constant<int>(f, 5), max_expected);
ASSERT_EQ(get_result_constant<int>(f, 6), abs_neg_expected);
ASSERT_EQ(get_result_constant<int>(f, 7), abs_neg_expected);
ASSERT_EQ(get_result_constant<int>(f, 8), sqrt_expected);
ASSERT_EQ(get_result_constant<int>(f, 9), relu_expected);
ASSERT_ANY_THROW(pass_manager.run_passes(f_error));
}
TEST(cpu_test, conv_test_winograd) TEST(cpu_test, conv_test_winograd)
{ {
/* This test checks for the cpu specific graph pass handling for conv_winograd implementation. /* This test checks for the cpu specific graph pass handling for conv_winograd implementation.
......
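
For reference, a minimal standalone sketch of the executor contract the pass relies on (the alias below mirrors NodeExecutorTy from util.hpp; the hard-coded float negate kernel is purely illustrative and not code from this commit). Inputs and outputs are handed to the functor as raw void* buffers, exactly as the fold_constant_* helpers do when a backend kernel is available:

#include <functional>
#include <iostream>
#include <vector>

// Mirrors ngraph's NodeExecutorTy alias from util.hpp.
using NodeExecutorTy =
    std::function<void(const std::vector<void*>& inputs, std::vector<void*>& outputs)>;

// A real builder would select a typed CPU kernel for the node; this toy
// version always negates element_count floats.
NodeExecutorTy build_negate_executor(size_t element_count)
{
    return [element_count](const std::vector<void*>& inputs, std::vector<void*>& outputs) {
        const float* in = static_cast<const float*>(inputs[0]);
        float* out = static_cast<float*>(outputs[0]);
        for (size_t i = 0; i < element_count; i++)
        {
            out[i] = -in[i];
        }
    };
}

int main()
{
    std::vector<float> in{1, -2, 3, -4};
    std::vector<float> out(4);
    std::vector<void*> inputs{in.data()};
    std::vector<void*> outputs{out.data()};
    build_negate_executor(4)(inputs, outputs); // same call shape as func(inputs, outputs) in the pass
    for (float v : out)
        std::cout << v << " "; // prints: -1 2 -3 4
    std::cout << "\n";
}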