Unverified commit cb431144 authored by Scott Cyphers, committed by GitHub

Merge branch 'master' into gauri/macos_cast

parents dba33d02 13bdf0ef
...@@ -419,8 +419,6 @@ set (SRC
pass/pass.hpp
pass/pass_config.cpp
pass/pass_config.hpp
pass/prefix_reshape_elimination.cpp
pass/prefix_reshape_elimination.hpp
pass/propagate_cacheability.cpp
pass/propagate_cacheability.hpp
pass/reshape_elimination.cpp
......
...@@ -650,6 +650,143 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta
this->add_matcher(m, callback);
}
// graph before this fusion:
//   input  mean  var  gamma  beta      broadcast1_input    broadcast2_input
//      \      \    |    /    /                |                   |
//            BatchNormInference           Broadcast1          Broadcast2
//                     \                       /                  /
//                      Multiply --------------                  /
//                          \                                   /
//                           Add --------------------------------
//                            |
//                          Relu
//
//
// graph after this fusion:
//   input  mean  var   gamma  broadcast1_input   beta  broadcast2_input
//      \      \    |       \      /        \       /          /
//       \      \   |      Multiply1       Multiply2          /
//        \      \  |          /                \            /
//         \      \ |         /                 newAdd-------
//          \      \|        /                    /
//           BatchNormInferenceRelu---------------
//
// Multiply1, Multiply2, and newAdd operate on vectors, while Multiply and Add operate on multi-dimensional tensors.
// Multiply1, Multiply2, and newAdd may be folded away by a later constant-folding pass.
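//
// Informal sketch of the algebra behind the rewrite (with b1 = broadcast1_input,
// b2 = broadcast2_input, and BN(input) = (input - mean) / sqrt(var + eps)):
//   Relu((BN(input) * gamma + beta) * b1 + b2)
//     = Relu(BN(input) * (gamma * b1) + (beta * b1 + b2))
// so the fused BatchNormInferenceRelu can use new_gamma = gamma * b1 (Multiply1) and
// new_beta = beta * b1 + b2 (Multiply2 followed by newAdd).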
void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_infer_relu_with_multiply_add()
{
auto input_shape = Shape{1, 3, 2, 2};
auto input = std::make_shared<pattern::op::Label>(element::f32, input_shape);
auto mean_shape = Shape{3};
auto mean = std::make_shared<pattern::op::Label>(element::f32, mean_shape);
auto var_shape = Shape{3};
auto var = std::make_shared<pattern::op::Label>(element::f32, var_shape);
auto gamma_shape = Shape{3};
auto gamma = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
auto beta_shape = Shape{3};
auto beta = std::make_shared<pattern::op::Label>(element::f32, beta_shape);
double eps = 0.001;
auto bn = std::make_shared<ngraph::op::BatchNormInference>(eps, gamma, beta, input, mean, var);
auto bn_label = std::make_shared<pattern::op::Label>(bn, nullptr, NodeVector{bn});
auto broadcast1_input = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
auto broadcast1 =
std::make_shared<ngraph::op::Broadcast>(broadcast1_input, input_shape, AxisSet{0, 2, 3});
auto broadcast1_label =
std::make_shared<pattern::op::Label>(broadcast1, nullptr, NodeVector{broadcast1});
auto multiply = std::make_shared<ngraph::op::Multiply>(bn_label, broadcast1_label);
auto multi_label =
std::make_shared<pattern::op::Label>(multiply, nullptr, NodeVector{multiply});
auto broadcast2_input = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
auto broadcast2 =
std::make_shared<ngraph::op::Broadcast>(broadcast2_input, input_shape, AxisSet{0, 2, 3});
auto broadcast2_label =
std::make_shared<pattern::op::Label>(broadcast2, nullptr, NodeVector{broadcast2});
auto add = std::make_shared<ngraph::op::Add>(multi_label, broadcast2_label);
auto prelu = std::make_shared<ngraph::op::Relu>(add);
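// Pattern root: Relu(Add(Multiply(BatchNormInference, Broadcast1), Broadcast2)).
// The concrete shapes above only seed the pattern labels; the tests below exercise 4-D and
// 5-D inputs, and the callback re-checks the broadcast axes against the actual graph.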
auto callback = [input,
mean,
var,
gamma,
beta,
bn_label,
multi_label,
broadcast1_input,
broadcast2_input,
broadcast1_label,
broadcast2_label](pattern::Matcher& m) {
NGRAPH_DEBUG
<< "In callback for construct_batch_norm_infer_relu_with_multi_add against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto bn_match = pattern_map[bn_label];
if (bn_match->get_users().size() > 1)
{
NGRAPH_DEBUG << "Multiply isn't the only user of BatchNorm's output";
return false;
}
auto multi_match = pattern_map[multi_label];
if (multi_match->get_users().size() > 1)
{
NGRAPH_DEBUG << "Add isn't the only user of Multiply's output";
return false;
}
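// Build the axis set {0, 2, 3, ...}: the fusion only applies when both Broadcasts
// broadcast a per-channel vector along every axis except the channel axis (axis 1).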
std::vector<size_t> vec{0};
for (auto i = 2; i < pattern_map[input]->output(0).get_shape().size(); i++)
{
vec.push_back(i);
}
AxisSet axisSet{vec};
if (std::static_pointer_cast<ngraph::op::Broadcast>(pattern_map[broadcast1_label])
->get_broadcast_axes() != axisSet ||
std::static_pointer_cast<ngraph::op::Broadcast>(pattern_map[broadcast2_label])
->get_broadcast_axes() != axisSet)
{
NGRAPH_DEBUG << "Broadcast axes is not {0, 2, ...}";
return false;
}
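// Fold the per-channel scale and shift into the batch-norm parameters:
// new_gamma = gamma * broadcast1_input, new_beta = beta * broadcast1_input + broadcast2_input.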
auto new_gamma = std::make_shared<ngraph::op::Multiply>(pattern_map[gamma],
pattern_map[broadcast1_input]);
auto new_multi = std::make_shared<ngraph::op::Multiply>(pattern_map[beta],
pattern_map[broadcast1_input]);
auto new_beta = std::make_shared<ngraph::op::Add>(new_multi, pattern_map[broadcast2_input]);
std::shared_ptr<Node> bn_relu;
if (auto bn_inference = std::dynamic_pointer_cast<ngraph::op::BatchNormInference>(bn_match))
{
if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(bn_inference.get()))
{
return false;
}
bn_relu =
std::make_shared<ngraph::op::BatchNormInferenceRelu>(bn_inference->get_eps_value(),
new_gamma,
new_beta,
pattern_map[input],
pattern_map[mean],
pattern_map[var]);
}
if (bn_relu)
{
ngraph::replace_node(m.get_match_root(), bn_relu);
return true;
}
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu,
"CPUFusion.BatchNormInferReluWithMultiAdd");
this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
{
Shape shape{2, 2, 1, 1};
......
...@@ -78,6 +78,7 @@ public:
construct_deconvolution_affine_folding_relu();
}
construct_dropout();
construct_batch_norm_infer_relu_with_multiply_add();
}
}
...@@ -90,6 +91,7 @@ private:
void construct_sigmoid_multiply();
void construct_batch_norm_relu();
void construct_batch_norm_relu_global_stats();
void construct_batch_norm_infer_relu_with_multiply_add();
void construct_conv_relu();
void construct_conv_bias_relu();
void construct_conv_bias_add();
......
...@@ -55,6 +55,7 @@ set(SRC
plaidml_pass_explicit_logicals.cpp
plaidml_pass_implicit_broadcast.cpp
plaidml_pass_lower_convolutions.cpp
plaidml_pass_prefix_reshape_elimination.cpp
plaidml_pass_replicate_combination.cpp
plaidml_pass_replicate_elision.cpp
plaidml_pass_winograd.cpp
......
...@@ -26,7 +26,6 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp" #include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
...@@ -36,6 +35,7 @@
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
...@@ -44,8 +44,7 @@ namespace
{
void write_debug(const ngraph::Node& op)
{
PLAIDML_DEBUG << "Node: name=\"" << op.get_name() << "\" desc=\"" << op.description()
              << "\"";
PLAIDML_DEBUG << "Compiling: " << op;
for (const auto& op_input : op.get_inputs())
{
ngraph::descriptor::Tensor* tensor = op_input.get_output().get_tensor_ptr().get();
...@@ -104,7 +103,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
pass_manager.register_pass<ngraph::pass::PrefixReshapeElimination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::PrefixReshapeElimination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>();
if (pass_manager.get_pass_config().get_pass_enable("Winograd"))
{
......
...@@ -163,6 +163,12 @@ ngraph::runtime::plaidml::Config
// So to verify that there is a non-zero-length option value, test oval_len
// To verify that there is no option value, test has_oval
if (oname_begin == oname_end && !has_oval)
{
// An empty option; poor style, but advance to the next.
continue;
}
// Check for verbosity
if (is_opt("v"))
{
......
...@@ -53,7 +53,7 @@ ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
void ngraph::runtime::plaidml::op::Replicate::validate_and_infer_types()
{
const auto& arg = get_arguments().at(0);
std::shared_ptr<Node> arg = get_argument(0);
Shape shape = arg->get_shape();
for (auto rit = m_replication_axes.begin(), sit = shape.begin();
rit != m_replication_axes.end();
......
...@@ -15,7 +15,7 @@
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/check.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
...@@ -76,9 +76,28 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
auto implicit_broadcast =
std::make_shared<plaidml::op::ImplicitBroadcast>(src, broadcast->get_shape());
replace_node(broadcast, implicit_broadcast);
// N.B. We don't use replace_node() here, since it's important to only replace the broadcast with an
// implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
// contractions don't provide implicit broadcast semantics.
bool result = false;
for (size_t i = 0; i < broadcast->get_output_size(); ++i)
{
for (auto& input : broadcast->output(i).get_target_inputs())
{
Node* node = input.get_node();
if (dynamic_cast<ngraph::op::util::UnaryElementwiseArithmetic*>(node) ||
dynamic_cast<ngraph::op::util::BinaryElementwiseArithmetic*>(node))
{
input.replace_source_output(implicit_broadcast->output(i));
result = true;
}
}
}
return true;
NGRAPH_CHECK(result,
"Expected at least one elementwise consumer in the PlaidML implicit broadcast "
"rewrite graph pass");
return result;
};
add_matcher(std::make_shared<pattern::Matcher>(target_op), callback);
}
...@@ -75,19 +75,19 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
// op. Using target always works.
AxisVector out_axes = to_axes(target, output_transpose);
auto lhs = node->get_arguments().at(0);
auto lhs = node->get_argument(0);
auto* lhs_transpose = to_transpose(lhs);
if (lhs_transpose)
{
lhs = lhs_transpose->get_arguments().at(0);
lhs = lhs_transpose->get_argument(0);
}
AxisVector lhs_axes = to_axes(lhs, lhs_transpose);
auto rhs = node->get_arguments().at(1);
auto rhs = node->get_argument(1);
auto* rhs_transpose = to_transpose(rhs);
if (rhs_transpose)
{
rhs = rhs_transpose->get_arguments().at(0);
rhs = rhs_transpose->get_argument(0);
}
AxisVector rhs_axes = to_axes(rhs, rhs_transpose);
......
...@@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
...@@ -23,11 +23,12 @@
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_implicit_broadcast.hpp"
using namespace std;
using namespace ngraph;
pass::PrefixReshapeElimination::PrefixReshapeElimination()
runtime::plaidml::pass::PrefixReshapeElimination::PrefixReshapeElimination()
{
auto src_op = make_shared<pattern::op::Label>(
element::i8, Shape{}, [](shared_ptr<Node>) { return true; });
...@@ -35,7 +36,7 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
element::i8,
Shape{},
[](shared_ptr<Node> node) {
op::Reshape* reshape = dynamic_cast<op::Reshape*>(node.get());
ngraph::op::Reshape* reshape = dynamic_cast<ngraph::op::Reshape*>(node.get());
if (!reshape)
{
return false;
...@@ -71,16 +72,42 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
element::i8,
Shape{},
[](shared_ptr<Node> node) {
return pattern::has_class<op::util::UnaryElementwiseArithmetic>()(node) ||
       pattern::has_class<op::util::BinaryElementwiseArithmetic>()(node);
return pattern::has_class<ngraph::op::util::UnaryElementwiseArithmetic>()(node) ||
       pattern::has_class<ngraph::op::util::BinaryElementwiseArithmetic>()(node);
},
NodeVector{reshape_op});
auto callback = [](pattern::Matcher& m) {
replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2));
return true;
auto src = m.get_matched_nodes().at(2);
auto prefix_reshape =
std::static_pointer_cast<ngraph::op::Reshape>(m.get_matched_nodes().at(1));
auto implicit_broadcast =
std::make_shared<op::ImplicitBroadcast>(src, prefix_reshape->get_shape());
// N.B. We don't use replace_node() here, since it's important to only replace the prefix reshape with
// an implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
// contractions don't provide implicit broadcast semantics.
bool result = false;
for (size_t i = 0; i < prefix_reshape->get_output_size(); ++i)
{
for (auto& input : prefix_reshape->output(i).get_target_inputs())
{
Node* node = input.get_node();
if (dynamic_cast<ngraph::op::util::UnaryElementwiseArithmetic*>(node) ||
dynamic_cast<ngraph::op::util::BinaryElementwiseArithmetic*>(node))
{
input.replace_source_output(implicit_broadcast->output(i));
result = true;
}
}
}
NGRAPH_CHECK(result,
"Expected at least one elementwise consumer in the PlaidML implicit broadcast "
"rewrite graph pass");
return result;
};
add_matcher(make_shared<pattern::Matcher>(target_op, "PrefixReshapeElimination"),
callback,
PassProperty::REQUIRE_STATIC_SHAPE);
ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE);
}
...@@ -20,19 +20,23 @@
namespace ngraph
{
namespace pass
namespace runtime
{
class PrefixReshapeElimination;
namespace plaidml
{
namespace pass
{
class PrefixReshapeElimination;
}
}
}
}
// A pass to eliminate reshapes whose output shapes are the same as
// their input shape modulo leading size-1 axes.
//
// N.B. This pass MUST only be used by backends that can handle the
// omission of leading size-1 axes, e.g. backends that implement
// NumPy-style broadcast semantics.
class ngraph::pass::PrefixReshapeElimination final : public ngraph::pass::GraphRewrite
// A pass that matches reshapes whose output shapes are the same as
// their input shape modulo leading size-1 axes, and replaces them with
// ImplicitBroadcast operations (which do the same thing as a passthrough).
class ngraph::runtime::plaidml::pass::PrefixReshapeElimination final
    : public ngraph::pass::GraphRewrite
{
public:
PrefixReshapeElimination();
......
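For illustration only (not part of this change), here is a minimal sketch of the kind of graph the pass targets, written against the v0-style nGraph API used elsewhere in this diff; the helper name is made up for the example. A Reshape counts as a "prefix reshape" when its output shape equals its input shape with extra leading size-1 axes, and the pass rewires elementwise consumers of such a Reshape to read from an ImplicitBroadcast of the original source instead.

#include "ngraph/ngraph.hpp"

using namespace ngraph;

// Hypothetical helper: builds a graph containing a prefix reshape
// (Shape{3, 4} -> Shape{1, 1, 3, 4}) feeding an elementwise Add.
std::shared_ptr<Function> make_prefix_reshape_example()
{
    auto src = std::make_shared<op::Parameter>(element::f32, Shape{3, 4});
    auto other = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 3, 4});
    // AxisVector{0, 1} keeps the element order; only leading size-1 axes are added.
    auto reshape = std::make_shared<op::Reshape>(src, AxisVector{0, 1}, Shape{1, 1, 3, 4});
    // The pass would replace this Add's reshape input with an ImplicitBroadcast of src.
    auto sum = std::make_shared<op::Add>(reshape, other);
    return std::make_shared<Function>(NodeVector{sum}, ParameterVector{src, other});
}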
...@@ -47,9 +47,9 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
*ait *= *uit;
}
replace_node(lower,
             std::make_shared<plaidml::op::Replicate>(upper->get_arguments().at(0),
                                                      std::move(axes)));
replace_node(
    lower,
    std::make_shared<plaidml::op::Replicate>(upper->get_argument(0), std::move(axes)));
return true;
};
......
...@@ -74,7 +74,7 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
if (elidable)
{
replaced_any = true;
replace_node(replicate, replicate->get_arguments().at(0));
replace_node(replicate, replicate->get_argument(0));
}
}
......
...@@ -560,6 +560,136 @@ TEST(cpu_fusion, conv_bias_bprop)
ASSERT_EQ(ccg, 1);
}
static void test_batchnorm_multiply_add_relu(Shape input_shape)
{
auto make_bn_relu_function = [&]() {
auto c_axis = input_shape[1];
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{c_axis};
auto mean = std::make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{c_axis};
auto var = std::make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{c_axis};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{c_axis};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto bn =
std::make_shared<ngraph::op::BatchNormInference>(eps, gamma, beta, input, mean, var);
std::vector<size_t> vec{0};
for (auto i = 2; i < input_shape.size(); i++)
{
vec.push_back(i);
}
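// vec is now {0, 2, 3, ...}: broadcast along every axis except the channel axis,
// which is the layout the CPU fusion pattern expects.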
auto broadcast1_input = std::make_shared<op::Parameter>(element::f32, gamma_shape);
auto broadcast1 =
std::make_shared<ngraph::op::Broadcast>(broadcast1_input, input_shape, AxisSet(vec));
auto multiply = std::make_shared<ngraph::op::Multiply>(bn, broadcast1);
auto broadcast2_input = std::make_shared<op::Parameter>(element::f32, gamma_shape);
auto broadcast2 =
std::make_shared<ngraph::op::Broadcast>(broadcast2_input, input_shape, AxisSet(vec));
auto add = std::make_shared<ngraph::op::Add>(multiply, broadcast2);
auto relu = std::make_shared<ngraph::op::Relu>(add);
auto f = make_shared<Function>(
relu,
ParameterVector{gamma, beta, input, mean, var, broadcast1_input, broadcast2_input});
return f;
};
auto cpu_f = make_bn_relu_function();
auto int_f = make_bn_relu_function();
test::Uniform<float> rng(1.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
size_t bn_relu = count_ops_of_type<op::BatchNormInferenceRelu>(cpu_f);
ASSERT_EQ(bn_relu, 1);
}
TEST(cpu_fusion, batchnorm_multiply_add_relu)
{
test_batchnorm_multiply_add_relu(Shape{1, 3, 2, 2});
test_batchnorm_multiply_add_relu(Shape{1, 2, 2, 2, 2});
test_batchnorm_multiply_add_relu(Shape{2, 2, 2, 4, 4});
}
TEST(cpu_fusion, batchnorm_multiply_add_relu_no_fusion)
{
auto input_shape = Shape{3, 3, 2, 2};
auto make_bn_relu_function = [&]() {
auto c_axis = input_shape[1];
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{c_axis};
auto mean = std::make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{c_axis};
auto var = std::make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{c_axis};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{c_axis};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto bn =
std::make_shared<ngraph::op::BatchNormInference>(eps, gamma, beta, input, mean, var);
std::vector<size_t> vec;
for (auto i = 1; i < input_shape.size(); i++)
{
vec.push_back(i);
}
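// vec is now {1, 2, 3}: the broadcast axes include the channel axis, so the fusion
// callback's AxisSet{0, 2, ...} check fails and no BatchNormInferenceRelu should be created.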
auto broadcast1_input = std::make_shared<op::Parameter>(element::f32, Shape{3});
auto broadcast1 =
std::make_shared<ngraph::op::Broadcast>(broadcast1_input, input_shape, AxisSet(vec));
auto multiply = std::make_shared<ngraph::op::Multiply>(bn, broadcast1);
auto broadcast2_input = std::make_shared<op::Parameter>(element::f32, Shape{3});
auto broadcast2 =
std::make_shared<ngraph::op::Broadcast>(broadcast2_input, input_shape, AxisSet(vec));
auto add = std::make_shared<ngraph::op::Add>(multiply, broadcast2);
auto relu = std::make_shared<ngraph::op::Relu>(add);
auto f = make_shared<Function>(
relu,
ParameterVector{gamma, beta, input, mean, var, broadcast1_input, broadcast2_input});
return f;
};
auto cpu_f = make_bn_relu_function();
auto int_f = make_bn_relu_function();
test::Uniform<float> rng(1.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
size_t bn_relu = count_ops_of_type<op::BatchNormInferenceRelu>(cpu_f);
ASSERT_EQ(bn_relu, 0);
}
TEST(cpu_fusion, batchnorm_fprop_relu_b1c2h2w2)
{
auto input_shape = Shape{1, 2, 2, 2};
......