Commit 31fae943, authored by Adam Procter, committed by Scott Cyphers

Update reference kernels, constant folding to support auto-broadcast directly (#3382)

* WIP

* CHANGE_DYNAMIC_STATE

* Implement full type prop for DynBroadcast when inputs const/static; clean up pass properties

* Add a unit test for the late-constness thing

* Update reference, interp, constant folding to handle autobroadcast

* Document the autobroadcast helper, and fix a corner case (zero's a number too!)

* Tests, and insert a check within the CPU folders for autobroadcast (it's still unsupported)

* Remove unnecessary include

* EXPECT_EQ -> ASSERT_EQ, for consistency

* Replace assert for CPU folders with a fallback
parent 6b5056e8
...@@ -856,7 +856,9 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -856,7 +856,9 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
{ {
auto out_shape = binary->get_shape(); auto out_shape = binary->get_shape();
if (func != nullptr) // NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape())
{ {
vector<Tout> out_vec(shape_size(out_shape)); vector<Tout> out_vec(shape_size(out_shape));
vector<void*> inputs; vector<void*> inputs;
...@@ -870,7 +872,7 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -870,7 +872,7 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
} }
else else
{ {
if (std::dynamic_pointer_cast<op::Add>(binary)) if (auto add_node = std::dynamic_pointer_cast<op::Add>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -878,10 +880,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -878,10 +880,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
runtime::reference::add<Tin>(a->get_data_ptr<Tin>(), runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
add_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::And>(binary)) else if (auto and_node = std::dynamic_pointer_cast<op::And>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -889,10 +893,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -889,10 +893,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
runtime::reference::logical_and<Tin>(a->get_data_ptr<Tin>(), runtime::reference::logical_and<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
and_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::Divide>(binary)) else if (auto divide_node = std::dynamic_pointer_cast<op::Divide>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -902,61 +908,73 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -902,61 +908,73 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(), runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape), a->get_shape(),
b->get_shape(),
divide_node->get_autob(),
pythondiv); pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::Equal>(binary)) else if (auto equal_node = std::dynamic_pointer_cast<op::Equal>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean"); NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(), runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
equal_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::Greater>(binary)) else if (auto greater_node = std::dynamic_pointer_cast<op::Greater>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean"); NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(), runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
greater_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::GreaterEq>(binary)) else if (auto greater_eq_node = std::dynamic_pointer_cast<op::GreaterEq>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean"); NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(), runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
greater_eq_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::Less>(binary)) else if (auto less_node = std::dynamic_pointer_cast<op::Less>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean"); NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(), runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
less_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::LessEq>(binary)) else if (auto less_eq_node = std::dynamic_pointer_cast<op::LessEq>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean"); NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(), runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
less_eq_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::Maximum>(binary)) else if (auto maximum_node = std::dynamic_pointer_cast<op::Maximum>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -964,10 +982,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -964,10 +982,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(), runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
maximum_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::Minimum>(binary)) else if (auto minimum_node = std::dynamic_pointer_cast<op::Minimum>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -975,10 +995,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -975,10 +995,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(), runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
minimum_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::Multiply>(binary)) else if (auto multiply_node = std::dynamic_pointer_cast<op::Multiply>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -986,20 +1008,24 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -986,20 +1008,24 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(), runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
multiply_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::NotEqual>(binary)) else if (auto not_equal_node = std::dynamic_pointer_cast<op::NotEqual>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean"); NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(), runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
not_equal_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::Or>(binary)) else if (auto or_node = std::dynamic_pointer_cast<op::Or>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -1007,10 +1033,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -1007,10 +1033,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
runtime::reference::logical_or<Tin>(a->get_data_ptr<Tin>(), runtime::reference::logical_or<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::Subtract>(binary)) else if (auto subtract_node = std::dynamic_pointer_cast<op::Subtract>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -1018,10 +1046,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -1018,10 +1046,12 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
runtime::reference::subtract<Tin>(a->get_data_ptr<Tin>(), runtime::reference::subtract<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
subtract_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (std::dynamic_pointer_cast<op::Xor>(binary)) else if (auto xor_node = std::dynamic_pointer_cast<op::Xor>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -1029,7 +1059,9 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -1029,7 +1059,9 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
runtime::reference::logical_xor<Tin>(a->get_data_ptr<Tin>(), runtime::reference::logical_xor<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
shape_size(out_shape)); a->get_shape(),
b->get_shape(),
xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else else
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include "ngraph/pass/assign_layout.hpp" #include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/core_fusion.hpp" #include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp" #include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp" #include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
...@@ -48,7 +47,6 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f ...@@ -48,7 +47,6 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>(); pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::FusedOpDecomposition>(); pass_manager.register_pass<pass::FusedOpDecomposition>();
pass_manager.register_pass<pass::ImplicitBroadcastElimination>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>(); pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
pass_manager.run_passes(m_function); pass_manager.run_passes(m_function);
......
...@@ -23,8 +23,10 @@ ...@@ -23,8 +23,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp" #include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp" #include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp" #include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp" #include "ngraph/op/argmin.hpp"
...@@ -39,6 +41,7 @@ ...@@ -39,6 +41,7 @@
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp" #include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp" #include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp" #include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp" #include "ngraph/op/experimental/dyn_pad.hpp"
...@@ -46,13 +49,23 @@ ...@@ -46,13 +49,23 @@
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/lrn.hpp" #include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp" #include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp" #include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp" #include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp" #include "ngraph/op/one_hot.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/passthrough.hpp" #include "ngraph/op/passthrough.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp" #include "ngraph/op/quantized_convolution.hpp"
...@@ -65,8 +78,10 @@ ...@@ -65,8 +78,10 @@
#include "ngraph/op/send.hpp" #include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/runtime/aligned_buffer.hpp" #include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/backend.hpp" #include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor.hpp" #include "ngraph/runtime/host_tensor.hpp"
...@@ -254,11 +269,13 @@ private: ...@@ -254,11 +269,13 @@ private:
} }
case OP_TYPEID::Add: case OP_TYPEID::Add:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); const op::Add* add = static_cast<const op::Add*>(&node);
reference::add<T>(args[0]->get_data_ptr<const T>(), reference::add<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
add->get_autob());
break; break;
} }
case OP_TYPEID::All: case OP_TYPEID::All:
...@@ -284,11 +301,13 @@ private: ...@@ -284,11 +301,13 @@ private:
} }
case OP_TYPEID::And: case OP_TYPEID::And:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto logical_and = static_cast<const op::And*>(&node);
reference::logical_and(args[0]->get_data_ptr<const T>(), reference::logical_and(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
logical_and->get_autob());
break; break;
} }
case OP_TYPEID::Any: case OP_TYPEID::Any:
...@@ -734,11 +753,12 @@ private: ...@@ -734,11 +753,12 @@ private:
case OP_TYPEID::Divide: case OP_TYPEID::Divide:
{ {
const op::Divide* divop = static_cast<const op::Divide*>(&node); const op::Divide* divop = static_cast<const op::Divide*>(&node);
size_t element_count = shape_size(node.get_output_shape(0));
reference::divide<T>(args[0]->get_data_ptr<const T>(), reference::divide<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
element_count, node.get_input_shape(0),
node.get_input_shape(1),
divop->get_autob(),
divop->is_pythondiv()); divop->is_pythondiv());
break; break;
} }
...@@ -812,11 +832,13 @@ private: ...@@ -812,11 +832,13 @@ private:
} }
case OP_TYPEID::Equal: case OP_TYPEID::Equal:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto equal = static_cast<const op::Equal*>(&node);
reference::equal<T>(args[0]->get_data_ptr<const T>(), reference::equal<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
equal->get_autob());
break; break;
} }
case OP_TYPEID::Erf: case OP_TYPEID::Erf:
...@@ -922,38 +944,46 @@ private: ...@@ -922,38 +944,46 @@ private:
} }
case OP_TYPEID::Greater: case OP_TYPEID::Greater:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto greater = static_cast<const op::Greater*>(&node);
reference::greater<T>(args[0]->get_data_ptr<const T>(), reference::greater<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
greater->get_autob());
break; break;
} }
case OP_TYPEID::GreaterEq: case OP_TYPEID::GreaterEq:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto greater_eq = static_cast<const op::GreaterEq*>(&node);
reference::greater_eq<T>(args[0]->get_data_ptr<const T>(), reference::greater_eq<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
greater_eq->get_autob());
break; break;
} }
case OP_TYPEID::Less: case OP_TYPEID::Less:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto less = static_cast<const op::Less*>(&node);
reference::less<T>(args[0]->get_data_ptr<const T>(), reference::less<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
less->get_autob());
break; break;
} }
case OP_TYPEID::LessEq: case OP_TYPEID::LessEq:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto less_eq = static_cast<const op::LessEq*>(&node);
reference::less_eq<T>(args[0]->get_data_ptr<const T>(), reference::less_eq<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
less_eq->get_autob());
break; break;
} }
case OP_TYPEID::Log: case OP_TYPEID::Log:
...@@ -987,11 +1017,13 @@ private: ...@@ -987,11 +1017,13 @@ private:
} }
case OP_TYPEID::Maximum: case OP_TYPEID::Maximum:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto maximum = static_cast<const op::Maximum*>(&node);
reference::maximum<T>(args[0]->get_data_ptr<const T>(), reference::maximum<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
maximum->get_autob());
break; break;
} }
case OP_TYPEID::MaxPool: case OP_TYPEID::MaxPool:
...@@ -1036,20 +1068,24 @@ private: ...@@ -1036,20 +1068,24 @@ private:
} }
case OP_TYPEID::Minimum: case OP_TYPEID::Minimum:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto minimum = static_cast<const op::Minimum*>(&node);
reference::minimum<T>(args[0]->get_data_ptr<const T>(), reference::minimum<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
minimum->get_autob());
break; break;
} }
case OP_TYPEID::Multiply: case OP_TYPEID::Multiply:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto multiply = static_cast<const op::Multiply*>(&node);
reference::multiply<T>(args[0]->get_data_ptr<const T>(), reference::multiply<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
multiply->get_autob());
break; break;
} }
case OP_TYPEID::Negative: case OP_TYPEID::Negative:
...@@ -1068,11 +1104,13 @@ private: ...@@ -1068,11 +1104,13 @@ private:
} }
case OP_TYPEID::NotEqual: case OP_TYPEID::NotEqual:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto not_equal = static_cast<const op::NotEqual*>(&node);
reference::not_equal<T>(args[0]->get_data_ptr<const T>(), reference::not_equal<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(), out[0]->get_data_ptr<char>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
not_equal->get_autob());
break; break;
} }
case OP_TYPEID::OneHot: case OP_TYPEID::OneHot:
...@@ -1087,11 +1125,13 @@ private: ...@@ -1087,11 +1125,13 @@ private:
} }
case OP_TYPEID::Or: case OP_TYPEID::Or:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto logical_or = static_cast<const op::Or*>(&node);
reference::logical_or(args[0]->get_data_ptr<const T>(), reference::logical_or(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
logical_or->get_autob());
break; break;
} }
case OP_TYPEID::Parameter: break; case OP_TYPEID::Parameter: break;
...@@ -1116,11 +1156,13 @@ private: ...@@ -1116,11 +1156,13 @@ private:
} }
case OP_TYPEID::Power: case OP_TYPEID::Power:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto power = static_cast<const op::Power*>(&node);
reference::power<T>(args[0]->get_data_ptr<const T>(), reference::power<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
power->get_autob());
break; break;
} }
case OP_TYPEID::Product: case OP_TYPEID::Product:
...@@ -1555,11 +1597,13 @@ private: ...@@ -1555,11 +1597,13 @@ private:
} }
case OP_TYPEID::Subtract: case OP_TYPEID::Subtract:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); auto subtract = static_cast<const op::Subtract*>(&node);
reference::subtract<T>(args[0]->get_data_ptr<const T>(), reference::subtract<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
subtract->get_autob());
break; break;
} }
case OP_TYPEID::Sum: case OP_TYPEID::Sum:
...@@ -1619,11 +1663,13 @@ private: ...@@ -1619,11 +1663,13 @@ private:
} }
case OP_TYPEID::Xor: case OP_TYPEID::Xor:
{ {
// This is the Xor case, so cast to op::Xor (not op::Or) before reading the auto-broadcast spec.
size_t element_count = shape_size(node.get_output_shape(0)); auto logical_xor = static_cast<const op::Xor*>(&node);
reference::logical_xor(args[0]->get_data_ptr<const T>(), reference::logical_xor(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
element_count); node.get_input_shape(0),
node.get_input_shape(1),
logical_xor->get_autob());
break; break;
} }
case OP_TYPEID::DynBroadcast: case OP_TYPEID::DynBroadcast:
......
...@@ -18,6 +18,11 @@ ...@@ -18,6 +18,11 @@
#include <cstddef> #include <cstddef>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/runtime/reference/autobroadcast_binop.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -32,6 +37,20 @@ namespace ngraph ...@@ -32,6 +37,20 @@ namespace ngraph
out[i] = arg0[i] + arg1[i]; out[i] = arg0[i] + arg1[i];
} }
} }
/// \brief Elementwise addition of two tensors with auto-broadcasting.
///
/// Forwards to autobroadcast_binop, supplying `x + y` as the elementwise operation.
///
/// \tparam T Element type of both inputs and of the output.
///
/// \param arg0 Pointer to the buffer for the left operand.
/// \param arg1 Pointer to the buffer for the right operand.
/// \param out Pointer to the buffer that receives the result.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme, passed through to autobroadcast_binop.
template <typename T>
void add(const T* arg0,
         const T* arg1,
         T* out,
         const Shape& arg0_shape,
         const Shape& arg1_shape,
         const op::AutoBroadcastSpec& broadcast_spec)
{
    const auto sum_of = [](T lhs, T rhs) -> T { return lhs + rhs; };
    autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, sum_of);
}
} }
} }
} }
...@@ -32,6 +32,20 @@ namespace ngraph ...@@ -32,6 +32,20 @@ namespace ngraph
out[i] = static_cast<T>(arg0[i] && arg1[i]); out[i] = static_cast<T>(arg0[i] && arg1[i]);
} }
} }
/// \brief Elementwise logical AND of two tensors with auto-broadcasting.
///
/// Forwards to autobroadcast_binop, supplying `x && y` (converted back to T) as the
/// elementwise operation.
///
/// \tparam T Element type of both inputs and of the output.
///
/// \param arg0 Pointer to the buffer for the left operand.
/// \param arg1 Pointer to the buffer for the right operand.
/// \param out Pointer to the buffer that receives the result.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme, passed through to autobroadcast_binop.
template <typename T>
void logical_and(const T* arg0,
                 const T* arg1,
                 T* out,
                 const Shape& arg0_shape,
                 const Shape& arg1_shape,
                 const op::AutoBroadcastSpec& broadcast_spec)
{
    const auto both_set = [](T lhs, T rhs) -> T { return static_cast<T>(lhs && rhs); };
    autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, both_set);
}
} }
} }
} }
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
/// \brief Applies an elementwise binary functor to two tensors, auto-broadcasting the
///        operands against each other when the broadcast spec asks for it.
///
/// \tparam T Element type of both input tensors.
/// \tparam U Element type of the output tensor.
/// \tparam Functor Type of the elementwise operation. It is invoked as functor(T, T)
///                 and its result is stored into an element of type U.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the output buffer. It must be pre-allocated by the caller and
///            large enough to hold a tensor of the result shape.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand. It is not consulted when the
///                   broadcast type is NONE.
/// \param broadcast_spec Specification of the auto-broadcasting scheme.
/// \param elementwise_functor Operation applied to each pair of input elements.
template <typename T, typename U, typename Functor>
void autobroadcast_binop(const T* arg0,
                         const T* arg1,
                         U* out,
                         const Shape& arg0_shape,
                         const Shape& arg1_shape,
                         const op::AutoBroadcastSpec& broadcast_spec,
                         Functor elementwise_functor)
{
    switch (broadcast_spec.m_type)
    {
    case op::AutoBroadcastType::NONE:
    {
        // No broadcasting: walk both inputs in lockstep over arg0's element count.
        const auto element_count = shape_size(arg0_shape);
        for (size_t idx = 0; idx < element_count; ++idx)
        {
            out[idx] = elementwise_functor(arg0[idx], arg1[idx]);
        }
        break;
    }
    case op::AutoBroadcastType::NUMPY:
    {
        // Broadcasting is driven by CoordinateTransform, in three steps:
        //
        //   (1) Left-pad the shorter shape with ones so both have the same rank.
        //   (2) Squeeze both shapes (drop the axes of length one), remembering which
        //       axes were dropped.
        //   (3) Iterate over the output shape; for each output coordinate, drop the
        //       squeezed axes to obtain the coordinate into each (squeezed) input.
        //
        // Example:
        //
        //    Input shape -> Padded shape -> Squeezed shape / Squeezed axes
        //    -----------    ------------    ------------------------------
        // a: [ 3, 2, 1]     [ 3, 2, 1]      [ 3, 2 ]        {2}
        // b: [    1, 6]     [ 1, 1, 6]      [       6]      {0,1}
        //
        // Output shape: [ 3, 2, 6]

        // Step 1: prepend ones until the shape reaches the requested rank.
        const auto left_pad_to_rank = [](Shape& shape, size_t rank) {
            while (shape.size() < rank)
            {
                shape.insert(shape.begin(), 1);
            }
        };

        Shape lhs_padded = arg0_shape;
        Shape rhs_padded = arg1_shape;
        left_pad_to_rank(lhs_padded, rhs_padded.size());
        left_pad_to_rank(rhs_padded, lhs_padded.size());

        // Step 2: split a padded shape into its non-unit dimensions and the set of
        // unit-length axes.
        const auto squeeze = [](const Shape& padded, Shape& squeezed, AxisSet& unit_axes) {
            for (size_t axis = 0; axis < padded.size(); ++axis)
            {
                if (padded[axis] == 1)
                {
                    unit_axes.insert(axis);
                }
                else
                {
                    squeezed.push_back(padded[axis]);
                }
            }
        };

        Shape lhs_squeezed;
        Shape rhs_squeezed;
        AxisSet lhs_unit_axes;
        AxisSet rhs_unit_axes;
        squeeze(lhs_padded, lhs_squeezed, lhs_unit_axes);
        squeeze(rhs_padded, rhs_squeezed, rhs_unit_axes);

        // Each output dimension comes from the left operand unless that dimension is
        // one, in which case the right operand's dimension is used.
        Shape result_shape;
        for (size_t axis = 0; axis < lhs_padded.size(); ++axis)
        {
            const size_t lhs_dim = lhs_padded[axis];
            result_shape.push_back(lhs_dim == 1 ? rhs_padded[axis] : lhs_dim);
        }

        // Step 3: visit every output coordinate and fetch the matching input elements.
        CoordinateTransform lhs_transform(lhs_squeezed);
        CoordinateTransform rhs_transform(rhs_squeezed);
        CoordinateTransform result_transform(result_shape);

        for (const Coordinate& result_coord : result_transform)
        {
            const Coordinate lhs_coord = reduce(result_coord, lhs_unit_axes);
            const Coordinate rhs_coord = reduce(result_coord, rhs_unit_axes);
            out[result_transform.index(result_coord)] =
                elementwise_functor(arg0[lhs_transform.index(lhs_coord)],
                                    arg1[rhs_transform.index(rhs_coord)]);
        }
        break;
    }
    }
}
}
}
}
...@@ -67,6 +67,47 @@ namespace ngraph ...@@ -67,6 +67,47 @@ namespace ngraph
} }
} }
/// \brief Reference elementwise integer division with auto-broadcast support.
///
/// This overload is enabled only for integral T. The traversal (and any broadcasting)
/// is delegated to autobroadcast_binop; each element pair is divided as follows:
///   - a zero divisor throws std::domain_error("integer division by zero");
///   - when pythondiv is set and the remainder is nonzero with operands of differing
///     sign, the built-in quotient is decremented by one (i.e. floor division);
///   - otherwise the built-in quotient x / y is returned.
///
/// \param arg0 Pointer to the buffer of the dividend.
/// \param arg1 Pointer to the buffer of the divisor.
/// \param out Pointer to the caller-allocated output buffer.
/// \param arg0_shape Shape of the dividend.
/// \param arg1_shape Shape of the divisor.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
/// \param pythondiv Selects the floor-division adjustment described above.
template <typename T>
typename std::enable_if<std::is_integral<T>::value>::type
    divide(const T* arg0,
           const T* arg1,
           T* out,
           const Shape& arg0_shape,
           const Shape& arg1_shape,
           const op::AutoBroadcastSpec& broadcast_spec,
           bool pythondiv)
{
    const auto quotient_of = [pythondiv](T numerator, T denominator) -> T {
        // Both division modes reject a zero divisor in exactly the same way.
        if (denominator == 0)
        {
            throw std::domain_error("integer division by zero");
        }

        const T truncated = numerator / denominator;
        if (!pythondiv)
        {
            return truncated;
        }

        // The built-in quotient only needs adjusting when the division is inexact and
        // the operands have opposite signs.
        const T remainder = numerator % denominator;
        const bool signs_differ = (numerator < 0) != (denominator < 0);
        if (remainder != 0 && signs_differ)
        {
            return truncated - 1;
        }
        return truncated;
    };

    autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, quotient_of);
}
// In English: return type is void and T must be a standard floating point type, or // In English: return type is void and T must be a standard floating point type, or
// bfloat16, or float16. // bfloat16, or float16.
template <typename T> template <typename T>
...@@ -83,6 +124,25 @@ namespace ngraph ...@@ -83,6 +124,25 @@ namespace ngraph
out[i] = arg0[i] / arg1[i]; out[i] = arg0[i] / arg1[i];
} }
} }
/// \brief Reference elementwise floating-point division with auto-broadcast support.
///
/// This overload is enabled when T is a standard floating-point type, bfloat16 or
/// float16. The traversal (and any broadcasting) is delegated to autobroadcast_binop,
/// with a functor that returns x / y.
///
/// \param arg0 Pointer to the buffer of the dividend.
/// \param arg1 Pointer to the buffer of the divisor.
/// \param out Pointer to the caller-allocated output buffer.
/// \param arg0_shape Shape of the dividend.
/// \param arg1_shape Shape of the divisor.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
/// \param pythondiv Not used by this overload.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value ||
                        std::is_same<T, bfloat16>::value ||
                        std::is_same<T, float16>::value>::type
    divide(const T* arg0,
           const T* arg1,
           T* out,
           const Shape& arg0_shape,
           const Shape& arg1_shape,
           const op::AutoBroadcastSpec& broadcast_spec,
           bool pythondiv)
{
    // The flag has no effect here; the cast only marks the parameter as deliberately
    // unused.
    (void)pythondiv;

    const auto quotient_of = [](T numerator, T denominator) -> T {
        return numerator / denominator;
    };
    autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, quotient_of);
}
} }
} }
} }
...@@ -40,6 +40,20 @@ namespace ngraph ...@@ -40,6 +40,20 @@ namespace ngraph
out[i] = arg0[i] == arg1[i]; out[i] = arg0[i] == arg1[i];
} }
} }
/// \brief Reference elementwise equality comparison with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. Each output
/// element is 1 when the corresponding input elements compare equal and 0 otherwise.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer of char flags.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void equal(const T* arg0,
           const T* arg1,
           char* out,
           const Shape& arg0_shape,
           const Shape& arg1_shape,
           const op::AutoBroadcastSpec& broadcast_spec)
{
    // The functor yields the output element type (char) directly, rather than routing
    // the boolean result through the input element type T.
    autobroadcast_binop(
        arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, [](T x, T y) -> char {
            return static_cast<char>(x == y);
        });
}
} }
} }
} }
......
...@@ -35,6 +35,20 @@ namespace ngraph ...@@ -35,6 +35,20 @@ namespace ngraph
out[i] = arg0[i] > arg1[i]; out[i] = arg0[i] > arg1[i];
} }
} }
/// \brief Reference elementwise greater-than comparison with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. Each output
/// element is 1 when the left element is greater than the right one and 0 otherwise.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer of char flags.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void greater(const T* arg0,
             const T* arg1,
             char* out,
             const Shape& arg0_shape,
             const Shape& arg1_shape,
             const op::AutoBroadcastSpec& broadcast_spec)
{
    // The functor yields the output element type (char) directly, rather than routing
    // the boolean result through the input element type T.
    autobroadcast_binop(
        arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, [](T x, T y) -> char {
            return static_cast<char>(x > y);
        });
}
} }
} }
} }
...@@ -35,6 +35,20 @@ namespace ngraph ...@@ -35,6 +35,20 @@ namespace ngraph
out[i] = arg0[i] >= arg1[i]; out[i] = arg0[i] >= arg1[i];
} }
} }
/// \brief Reference elementwise greater-or-equal comparison with auto-broadcast
///        support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. Each output
/// element is 1 when the left element is greater than or equal to the right one and 0
/// otherwise.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer of char flags.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void greater_eq(const T* arg0,
                const T* arg1,
                char* out,
                const Shape& arg0_shape,
                const Shape& arg1_shape,
                const op::AutoBroadcastSpec& broadcast_spec)
{
    // The functor yields the output element type (char) directly, rather than routing
    // the boolean result through the input element type T.
    autobroadcast_binop(
        arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, [](T x, T y) -> char {
            return static_cast<char>(x >= y);
        });
}
} }
} }
} }
...@@ -35,6 +35,20 @@ namespace ngraph ...@@ -35,6 +35,20 @@ namespace ngraph
out[i] = arg0[i] < arg1[i]; out[i] = arg0[i] < arg1[i];
} }
} }
/// \brief Reference elementwise less-than comparison with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. Each output
/// element is 1 when the left element is less than the right one and 0 otherwise.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer of char flags.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void less(const T* arg0,
          const T* arg1,
          char* out,
          const Shape& arg0_shape,
          const Shape& arg1_shape,
          const op::AutoBroadcastSpec& broadcast_spec)
{
    // The functor yields the output element type (char) directly, rather than routing
    // the boolean result through the input element type T.
    autobroadcast_binop(
        arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, [](T x, T y) -> char {
            return static_cast<char>(x < y);
        });
}
} }
} }
} }
...@@ -35,6 +35,20 @@ namespace ngraph ...@@ -35,6 +35,20 @@ namespace ngraph
out[i] = arg0[i] <= arg1[i]; out[i] = arg0[i] <= arg1[i];
} }
} }
/// \brief Reference elementwise less-or-equal comparison with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. Each output
/// element is 1 when the left element is less than or equal to the right one and 0
/// otherwise.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer of char flags.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void less_eq(const T* arg0,
             const T* arg1,
             char* out,
             const Shape& arg0_shape,
             const Shape& arg1_shape,
             const op::AutoBroadcastSpec& broadcast_spec)
{
    // The functor yields the output element type (char) directly, rather than routing
    // the boolean result through the input element type T.
    autobroadcast_binop(
        arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, [](T x, T y) -> char {
            return static_cast<char>(x <= y);
        });
}
} }
} }
} }
...@@ -32,6 +32,20 @@ namespace ngraph ...@@ -32,6 +32,20 @@ namespace ngraph
out[i] = arg0[i] > arg1[i] ? arg0[i] : arg1[i]; out[i] = arg0[i] > arg1[i] ? arg0[i] : arg1[i];
} }
} }
/// \brief Reference elementwise maximum with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. For each
/// element pair the left value is chosen when it compares greater than the right one;
/// otherwise the right value is chosen.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void maximum(const T* arg0,
             const T* arg1,
             T* out,
             const Shape& arg0_shape,
             const Shape& arg1_shape,
             const op::AutoBroadcastSpec& broadcast_spec)
{
    const auto larger_of = [](T lhs, T rhs) -> T {
        if (lhs > rhs)
        {
            return lhs;
        }
        return rhs;
    };
    autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, larger_of);
}
} }
} }
} }
...@@ -32,6 +32,20 @@ namespace ngraph ...@@ -32,6 +32,20 @@ namespace ngraph
out[i] = arg0[i] < arg1[i] ? arg0[i] : arg1[i]; out[i] = arg0[i] < arg1[i] ? arg0[i] : arg1[i];
} }
} }
/// \brief Reference elementwise minimum with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. For each
/// element pair the left value is chosen when it compares less than the right one;
/// otherwise the right value is chosen.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void minimum(const T* arg0,
             const T* arg1,
             T* out,
             const Shape& arg0_shape,
             const Shape& arg1_shape,
             const op::AutoBroadcastSpec& broadcast_spec)
{
    const auto smaller_of = [](T lhs, T rhs) -> T {
        if (lhs < rhs)
        {
            return lhs;
        }
        return rhs;
    };
    autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, smaller_of);
}
} }
} }
} }
...@@ -32,6 +32,20 @@ namespace ngraph ...@@ -32,6 +32,20 @@ namespace ngraph
out[i] = arg0[i] * arg1[i]; out[i] = arg0[i] * arg1[i];
} }
} }
/// \brief Reference elementwise multiplication with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop, supplying a
/// functor that returns the product of its two arguments.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void multiply(const T* arg0,
              const T* arg1,
              T* out,
              const Shape& arg0_shape,
              const Shape& arg1_shape,
              const op::AutoBroadcastSpec& broadcast_spec)
{
    const auto product_of = [](T lhs, T rhs) -> T { return lhs * rhs; };
    autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, product_of);
}
} }
} }
} }
...@@ -40,6 +40,20 @@ namespace ngraph ...@@ -40,6 +40,20 @@ namespace ngraph
out[i] = arg0[i] != arg1[i]; out[i] = arg0[i] != arg1[i];
} }
} }
/// \brief Reference elementwise inequality comparison with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. Each output
/// element is 1 when the corresponding input elements compare unequal and 0 otherwise.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer of char flags.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void not_equal(const T* arg0,
               const T* arg1,
               char* out,
               const Shape& arg0_shape,
               const Shape& arg1_shape,
               const op::AutoBroadcastSpec& broadcast_spec)
{
    // The functor yields the output element type (char) directly, rather than routing
    // the boolean result through the input element type T.
    autobroadcast_binop(
        arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, [](T x, T y) -> char {
            return static_cast<char>(x != y);
        });
}
} }
} }
} }
......
...@@ -32,6 +32,20 @@ namespace ngraph ...@@ -32,6 +32,20 @@ namespace ngraph
out[i] = static_cast<T>(arg0[i] || arg1[i]); out[i] = static_cast<T>(arg0[i] || arg1[i]);
} }
} }
/// \brief Reference elementwise logical OR with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. Each pair of
/// elements is combined with `||` and the boolean result is cast back to T.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void logical_or(const T* arg0,
                const T* arg1,
                T* out,
                const Shape& arg0_shape,
                const Shape& arg1_shape,
                const op::AutoBroadcastSpec& broadcast_spec)
{
    const auto either_set = [](T lhs, T rhs) -> T {
        const bool result = lhs || rhs;
        return static_cast<T>(result);
    };
    autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, either_set);
}
} }
} }
} }
...@@ -33,6 +33,20 @@ namespace ngraph ...@@ -33,6 +33,20 @@ namespace ngraph
out[i] = std::pow(arg0[i], arg1[i]); out[i] = std::pow(arg0[i], arg1[i]);
} }
} }
/// \brief Reference elementwise exponentiation with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. Each output
/// element is std::pow(base, exponent), converted to T.
///
/// \param arg0 Pointer to the buffer of the bases.
/// \param arg1 Pointer to the buffer of the exponents.
/// \param out Pointer to the caller-allocated output buffer.
/// \param arg0_shape Shape of the base tensor.
/// \param arg1_shape Shape of the exponent tensor.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void power(const T* arg0,
           const T* arg1,
           T* out,
           const Shape& arg0_shape,
           const Shape& arg1_shape,
           const op::AutoBroadcastSpec& broadcast_spec)
{
    const auto raise = [](T base, T exponent) -> T { return std::pow(base, exponent); };
    autobroadcast_binop(arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, raise);
}
} }
} }
} }
...@@ -32,6 +32,20 @@ namespace ngraph ...@@ -32,6 +32,20 @@ namespace ngraph
out[i] = arg0[i] - arg1[i]; out[i] = arg0[i] - arg1[i];
} }
} }
/// \brief Reference elementwise subtraction with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop, supplying a
/// functor that returns the left argument minus the right argument.
///
/// \param arg0 Pointer to the buffer of the minuend.
/// \param arg1 Pointer to the buffer of the subtrahend.
/// \param out Pointer to the caller-allocated output buffer.
/// \param arg0_shape Shape of the minuend.
/// \param arg1_shape Shape of the subtrahend.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void subtract(const T* arg0,
              const T* arg1,
              T* out,
              const Shape& arg0_shape,
              const Shape& arg1_shape,
              const op::AutoBroadcastSpec& broadcast_spec)
{
    const auto difference_of = [](T lhs, T rhs) -> T { return lhs - rhs; };
    autobroadcast_binop(
        arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, difference_of);
}
} }
} }
} }
...@@ -32,6 +32,20 @@ namespace ngraph ...@@ -32,6 +32,20 @@ namespace ngraph
out[i] = static_cast<T>((arg0[i] || arg1[i]) && !(arg0[i] && arg1[i])); out[i] = static_cast<T>((arg0[i] || arg1[i]) && !(arg0[i] && arg1[i]));
} }
} }
/// \brief Reference elementwise logical XOR with auto-broadcast support.
///
/// Delegates the traversal (and any broadcasting) to autobroadcast_binop. The result
/// for each element pair is true when at least one operand is truthy but not both; the
/// boolean result is cast back to T.
///
/// \param arg0 Pointer to the buffer of the left operand.
/// \param arg1 Pointer to the buffer of the right operand.
/// \param out Pointer to the caller-allocated output buffer.
/// \param arg0_shape Shape of the left operand.
/// \param arg1_shape Shape of the right operand.
/// \param broadcast_spec Auto-broadcasting scheme forwarded to autobroadcast_binop.
template <typename T>
void logical_xor(const T* arg0,
                 const T* arg1,
                 T* out,
                 const Shape& arg0_shape,
                 const Shape& arg1_shape,
                 const op::AutoBroadcastSpec& broadcast_spec)
{
    const auto exactly_one_set = [](T lhs, T rhs) -> T {
        const bool either = lhs || rhs;
        const bool both = lhs && rhs;
        return static_cast<T>(either && !both);
    };
    autobroadcast_binop(
        arg0, arg1, out, arg0_shape, arg1_shape, broadcast_spec, exactly_one_set);
}
} }
} }
} }
...@@ -168,15 +168,24 @@ static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t po ...@@ -168,15 +168,24 @@ static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t po
TEST(constant_folding, constant_unary_binary) TEST(constant_folding, constant_unary_binary)
{ {
Shape shape_in{4};
vector<int> values_a{1, 2, 3, 4}; vector<int> values_a{1, 2, 3, 4};
vector<int> values_b{1, 2, 3, 4}; vector<int> values_b{1, 2, 3, 4};
vector<int> values_c{-1, -1, -1, -1}; vector<int> values_c{-1, -1, -1, -1};
vector<int> values_d{1, 4, 9, 16}; vector<int> values_d{1, 4, 9, 16};
auto a = make_shared<op::Constant>(element::i32, shape_in, values_a); vector<int> values_e{5, 6};
auto b = make_shared<op::Constant>(element::i32, shape_in, values_b); vector<int> values_f{0, 10};
auto c = make_shared<op::Constant>(element::i32, shape_in, values_c); vector<int> values_g{1, 4};
auto d = make_shared<op::Constant>(element::i32, shape_in, values_d); vector<char> values_h{0, 0, 1, 1};
vector<char> values_i{0, 1};
auto a = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_a);
auto b = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_b);
auto c = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_c);
auto d = make_shared<op::Constant>(element::i32, Shape{2, 2}, values_d);
auto e = make_shared<op::Constant>(element::i32, Shape{2}, values_e);
auto f = make_shared<op::Constant>(element::i32, Shape{2}, values_f);
auto g = make_shared<op::Constant>(element::i32, Shape{2}, values_g);
auto h = make_shared<op::Constant>(element::boolean, Shape{2, 2}, values_h);
auto i = make_shared<op::Constant>(element::boolean, Shape{2}, values_i);
auto add = a + b; auto add = a + b;
auto sub = a - b; auto sub = a - b;
...@@ -187,15 +196,54 @@ TEST(constant_folding, constant_unary_binary) ...@@ -187,15 +196,54 @@ TEST(constant_folding, constant_unary_binary)
auto absn = make_shared<op::Abs>(c); auto absn = make_shared<op::Abs>(c);
auto neg = make_shared<op::Negative>(c); auto neg = make_shared<op::Negative>(c);
auto sqrt = make_shared<op::Sqrt>(d); auto sqrt = make_shared<op::Sqrt>(d);
auto add_autob_numpy = make_shared<op::Add>(a, e, op::AutoBroadcastType::NUMPY);
auto sub_autob_numpy = make_shared<op::Subtract>(a, e, op::AutoBroadcastType::NUMPY);
auto mul_autob_numpy = make_shared<op::Multiply>(a, e, op::AutoBroadcastType::NUMPY);
auto div_autob_numpy = make_shared<op::Divide>(a, g, op::AutoBroadcastType::NUMPY);
auto min_autob_numpy = make_shared<op::Minimum>(a, f, op::AutoBroadcastType::NUMPY);
auto max_autob_numpy = make_shared<op::Maximum>(a, f, op::AutoBroadcastType::NUMPY);
auto equal_autob_numpy = make_shared<op::Equal>(a, g, op::AutoBroadcastType::NUMPY);
auto not_equal_autob_numpy = make_shared<op::NotEqual>(a, g, op::AutoBroadcastType::NUMPY);
auto greater_autob_numpy = make_shared<op::Greater>(a, g, op::AutoBroadcastType::NUMPY);
auto greater_eq_autob_numpy = make_shared<op::GreaterEq>(a, g, op::AutoBroadcastType::NUMPY);
auto less_autob_numpy = make_shared<op::Less>(a, g, op::AutoBroadcastType::NUMPY);
auto less_eq_autob_numpy = make_shared<op::LessEq>(a, g, op::AutoBroadcastType::NUMPY);
auto logical_and_autob_numpy = make_shared<op::And>(h, i, op::AutoBroadcastType::NUMPY);
auto logical_or_autob_numpy = make_shared<op::Or>(h, i, op::AutoBroadcastType::NUMPY);
auto logical_xor_autob_numpy = make_shared<op::Xor>(h, i, op::AutoBroadcastType::NUMPY);
auto neg_sqrt = make_shared<op::Sqrt>(c); auto neg_sqrt = make_shared<op::Sqrt>(c);
auto f = make_shared<Function>(NodeVector{add, sub, mul, divn, min, max, absn, neg, sqrt}, auto func = make_shared<Function>(NodeVector{add,
sub,
mul,
divn,
min,
max,
absn,
neg,
sqrt,
add_autob_numpy,
sub_autob_numpy,
mul_autob_numpy,
div_autob_numpy,
min_autob_numpy,
max_autob_numpy,
equal_autob_numpy,
not_equal_autob_numpy,
greater_autob_numpy,
greater_eq_autob_numpy,
less_autob_numpy,
less_eq_autob_numpy,
logical_and_autob_numpy,
logical_or_autob_numpy,
logical_xor_autob_numpy},
ParameterVector{}); ParameterVector{});
auto f_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{}); auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>(); pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f); pass_manager.run_passes(func);
//expected values //expected values
vector<int> add_expected{2, 4, 6, 8}; vector<int> add_expected{2, 4, 6, 8};
...@@ -206,17 +254,47 @@ TEST(constant_folding, constant_unary_binary) ...@@ -206,17 +254,47 @@ TEST(constant_folding, constant_unary_binary)
vector<int> max_expected{1, 2, 3, 4}; vector<int> max_expected{1, 2, 3, 4};
vector<int> abs_neg_expected{1, 1, 1, 1}; vector<int> abs_neg_expected{1, 1, 1, 1};
vector<int> sqrt_expected{1, 2, 3, 4}; vector<int> sqrt_expected{1, 2, 3, 4};
vector<int> add_autob_numpy_expected{6, 8, 8, 10};
ASSERT_EQ(get_result_constant<int>(f, 0), add_expected); vector<int> sub_autob_numpy_expected{-4, -4, -2, -2};
ASSERT_EQ(get_result_constant<int>(f, 1), sub_expected); vector<int> mul_autob_numpy_expected{5, 12, 15, 24};
ASSERT_EQ(get_result_constant<int>(f, 2), mul_expected); vector<int> div_autob_numpy_expected{1, 0, 3, 1};
ASSERT_EQ(get_result_constant<int>(f, 3), div_expected); vector<int> min_autob_numpy_expected{0, 2, 0, 4};
ASSERT_EQ(get_result_constant<int>(f, 4), min_expected); vector<int> max_autob_numpy_expected{1, 10, 3, 10};
ASSERT_EQ(get_result_constant<int>(f, 5), max_expected); vector<char> equal_autob_numpy_expected{1, 0, 0, 1};
ASSERT_EQ(get_result_constant<int>(f, 6), abs_neg_expected); vector<char> not_equal_autob_numpy_expected{0, 1, 1, 0};
ASSERT_EQ(get_result_constant<int>(f, 7), abs_neg_expected); vector<char> greater_autob_numpy_expected{0, 0, 1, 0};
ASSERT_EQ(get_result_constant<int>(f, 8), sqrt_expected); vector<char> greater_eq_autob_numpy_expected{1, 0, 1, 1};
ASSERT_ANY_THROW(pass_manager.run_passes(f_error)); vector<char> less_autob_numpy_expected{0, 1, 0, 0};
vector<char> less_eq_autob_numpy_expected{1, 1, 0, 1};
vector<char> logical_and_autob_numpy_expected{0, 0, 0, 1};
vector<char> logical_or_autob_numpy_expected{0, 1, 1, 1};
vector<char> logical_xor_autob_numpy_expected{0, 1, 1, 0};
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
ASSERT_EQ(get_result_constant<int>(func, 2), mul_expected);
ASSERT_EQ(get_result_constant<int>(func, 3), div_expected);
ASSERT_EQ(get_result_constant<int>(func, 4), min_expected);
ASSERT_EQ(get_result_constant<int>(func, 5), max_expected);
ASSERT_EQ(get_result_constant<int>(func, 6), abs_neg_expected);
ASSERT_EQ(get_result_constant<int>(func, 7), abs_neg_expected);
ASSERT_EQ(get_result_constant<int>(func, 8), sqrt_expected);
ASSERT_EQ(get_result_constant<int>(func, 9), add_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 10), sub_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 11), mul_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 12), div_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 13), min_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 14), max_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 15), equal_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 16), not_equal_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 17), greater_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 18), greater_eq_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 19), less_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 20), less_eq_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 21), logical_and_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 22), logical_or_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 23), logical_xor_autob_numpy_expected);
ASSERT_ANY_THROW(pass_manager.run_passes(func_error));
} }
TEST(constant_folding, const_dequantize) TEST(constant_folding, const_dequantize)
......
...@@ -1172,7 +1172,7 @@ static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t po ...@@ -1172,7 +1172,7 @@ static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t po
TEST(cpu_test, constant_unary_binary) TEST(cpu_test, constant_unary_binary)
{ {
Shape shape_in{4}; Shape shape_in{2, 2};
vector<int> values_a{1, 2, 3, 4}; vector<int> values_a{1, 2, 3, 4};
vector<int> values_b{1, 2, 3, 4}; vector<int> values_b{1, 2, 3, 4};
vector<int> values_c{-1, -1, -1, -1}; vector<int> values_c{-1, -1, -1, -1};
...@@ -1184,6 +1184,7 @@ TEST(cpu_test, constant_unary_binary) ...@@ -1184,6 +1184,7 @@ TEST(cpu_test, constant_unary_binary)
vector<char> values_i{0, 0, 1, 1}; vector<char> values_i{0, 0, 1, 1};
vector<char> values_j{0, 1, 0, 1}; vector<char> values_j{0, 1, 0, 1};
vector<float> values_k{-0.1f, 0.0f, -1.5f, 2.6f}; vector<float> values_k{-0.1f, 0.0f, -1.5f, 2.6f};
vector<int> values_l{1, 2};
auto a = make_shared<op::Constant>(element::i32, shape_in, values_a); auto a = make_shared<op::Constant>(element::i32, shape_in, values_a);
auto b = make_shared<op::Constant>(element::i32, shape_in, values_b); auto b = make_shared<op::Constant>(element::i32, shape_in, values_b);
auto c = make_shared<op::Constant>(element::i32, shape_in, values_c); auto c = make_shared<op::Constant>(element::i32, shape_in, values_c);
...@@ -1195,6 +1196,7 @@ TEST(cpu_test, constant_unary_binary) ...@@ -1195,6 +1196,7 @@ TEST(cpu_test, constant_unary_binary)
auto i = make_shared<op::Constant>(element::boolean, shape_in, values_i); auto i = make_shared<op::Constant>(element::boolean, shape_in, values_i);
auto j = make_shared<op::Constant>(element::boolean, shape_in, values_j); auto j = make_shared<op::Constant>(element::boolean, shape_in, values_j);
auto k = make_shared<op::Constant>(element::f32, shape_in, values_k); auto k = make_shared<op::Constant>(element::f32, shape_in, values_k);
auto l = make_shared<op::Constant>(element::i32, Shape{2}, values_l);
auto add = a + b; auto add = a + b;
auto sub = a - b; auto sub = a - b;
...@@ -1220,12 +1222,17 @@ TEST(cpu_test, constant_unary_binary) ...@@ -1220,12 +1222,17 @@ TEST(cpu_test, constant_unary_binary)
auto ceil = make_shared<op::Ceiling>(k); auto ceil = make_shared<op::Ceiling>(k);
auto floor = make_shared<op::Floor>(k); auto floor = make_shared<op::Floor>(k);
auto logical_not = make_shared<op::Not>(j); auto logical_not = make_shared<op::Not>(j);
// Note: The CPU functors do not actually support autobroadcast yet; instead the pass itself
// falls back if autobroadcasting is in use. Putting this check here just to make sure the
// fallback works as expected, but if direct support for autobroadcast is added to the CPU
// folders we should add more comprehensive tests here. --amprocte
auto add_autob_numpy = make_shared<op::Add>(a, l, op::AutoBroadcastType::NUMPY);
auto func = make_shared<Function>( auto func = make_shared<Function>(
NodeVector{add, sub, mul, divn, min, max, NodeVector{add, sub, mul, divn, min, max,
absn, neg, sqrt, relu, sign, equal, absn, neg, sqrt, relu, sign, equal,
not_equal, greater, greater_eq, less, less_eq, logical_and, not_equal, greater, greater_eq, less, less_eq, logical_and,
logical_or, logical_xor, ceil, floor, logical_not}, logical_or, logical_xor, ceil, floor, logical_not, add_autob_numpy},
ParameterVector{}); ParameterVector{});
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{}); auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
...@@ -1282,6 +1289,7 @@ TEST(cpu_test, constant_unary_binary) ...@@ -1282,6 +1289,7 @@ TEST(cpu_test, constant_unary_binary)
vector<float> ceil_expected{0.0f, 0.0f, -1.0f, 3.0f}; vector<float> ceil_expected{0.0f, 0.0f, -1.0f, 3.0f};
vector<float> floor_expected{-1.0f, 0.0f, -2.0f, 2.0f}; vector<float> floor_expected{-1.0f, 0.0f, -2.0f, 2.0f};
vector<char> not_expected{1, 0, 1, 0}; vector<char> not_expected{1, 0, 1, 0};
vector<int> add_autob_numpy_expected{2, 4, 4, 6};
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected); ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected); ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
...@@ -1308,6 +1316,7 @@ TEST(cpu_test, constant_unary_binary) ...@@ -1308,6 +1316,7 @@ TEST(cpu_test, constant_unary_binary)
ASSERT_TRUE(test::all_close_f( ASSERT_TRUE(test::all_close_f(
get_result_constant<float>(func, 21), floor_expected, MIN_FLOAT_TOLERANCE_BITS)); get_result_constant<float>(func, 21), floor_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_EQ(get_result_constant<char>(func, 22), not_expected); ASSERT_EQ(get_result_constant<char>(func, 22), not_expected);
ASSERT_EQ(get_result_constant<int>(func, 23), add_autob_numpy_expected);
ASSERT_ANY_THROW(pass_manager.run_passes(func_error)); ASSERT_ANY_THROW(pass_manager.run_passes(func_error));
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment