Unverified Commit 34b1322d authored by Pruthvi, committed by GitHub

pattern matcher for BatchnormFprop + mkldnn integration in the CPU emitter (#468)

* fuse dot(a,b) + c

cblas_gemm working on mlp

rebase & small fixes

enable debug output

support replacing function's outputs

* WIP pattern matching for variance

* - Added a pattern matcher graph to look up the variance sub-graph in bn
- Added a test case to verify the variance graph pattern

* added batch norm mean pattern matcher.

* remove reshapes

(cherry picked from commit ecad321fb1b1bc3f7facda229beb940118ca0701)

* fixed mean test to use Matcher.

* resolve merge conflict in test/pattern.cpp

* WIP bn fprop pattern

* fprop bn fusion working

* - Added unit test case to read the serialized bn *.json file and run the bn fprop fusion pass
- Added batchnorm header file and defined the bn class to emit the mkldnn kernel
- Added pattern matcher for fprop bn in CPU graph_rewrite pass

* WIP MKLDNN fprop bn emitter code

* completed fprop batchnorm kernel in CPU emitter

* fixed bug in the emitter code for fprop bn

* - Fixed compilation issues
- unit tests are passing for bn emitter fprop code

* Added support to compute fprop bn with mean and variance as inputs

* resolved compilation issues

* refactored bn fprop code

* - added batchnorm src file to the CMake file list
- moved bn fusion under CPU runtime/pass/cpu_fusion
- fixed compilation issue

* Resolved compilation issues in bn emitted code

* Added debug statements in the fprop bn emitted code

* added batchnorm.cpp src file

* - Added test case to test fprop batchnorm with known tensor values
- fixed bug related to defining weights in fprop bn

* - Added test case for fprop batchnorm Op
- Added test case for mean and variance pattern matcher
- Added fprop bn *.json file with a 4-dim input (mb2c3h2w2)
- refactored fprop bn op class

* Style fix

* - Removed Debug symbols

* - Fixed header template with correct year
- appended mkldnn.hpp in the CPU generated code

*  Addressed PR review comments
 -  added support for batchnorm op in serializer and de-serializer
 - added more sanity checks to the bn constructor
 - renamed "BatchnormFprop" -> BatchNorm

* - Addressed PR review comments
- replaced auto with specific mkldnn types in the emitted bn kernel
- modified function signature to take 'eps' as double instead of <Node> type

* added missing header files, resolved compilation issue

* style fix

* Addressed PR comments
1. initialized bn member variables in the same order as they are defined
2. renamed bn member variables to start with m_* as per coding convention
3. moved bn fusion test to test/cpu_fusion.cpp
4. style fix
5. added more checks to evaluate type and shape of inputs to bn

* Added support for EMITDECL macro for batchnorm

* - corrected the batchnorm src file name (batchnorm -> batch_norm) as per coding guidelines
- corrected bn copy_with_new_args() method

* Removed redundant SqrtOp support in serializer
parent 5c8f9222
...@@ -36,6 +36,7 @@ set (SRC
ops/abs.cpp
ops/add.cpp
ops/avg_pool.cpp
ops/batch_norm.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_comparison.cpp
ops/binary_elementwise.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/constant.hpp"
ngraph::op::BatchNorm::BatchNorm(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance)
: RequiresTensorViewArgs("BatchNorm", {gamma, beta, input, mean, variance})
, m_bn_input_shape(input->get_shape())
, m_bn_variance_shape(variance->get_shape())
, m_bn_mean_shape(mean->get_shape())
, m_epsilon(eps)
{
add_output(input->get_element_type(), m_bn_input_shape);
if (m_bn_input_shape.size() < 2)
{
throw ngraph_error("input tensor to batchnorm much have tensor of atleast rank 2");
}
if (m_bn_input_shape[1] == 0)
{
throw ngraph_error(
"input tensor must have atleast one channel axis for batch normalization");
}
if ((m_bn_mean_shape.size() != 1) || (m_bn_variance_shape.size() != 1) ||
(gamma->get_shape().size() != 1) || (beta->get_shape().size() != 1))
{
throw ngraph_error("gamma, beta, mean, and variance should all have rank 1");
}
// assuming input shape (N, C, H, W), check if the size of mean and
// variance are equal to channel axis
if (mean->get_shape()[0] != m_bn_input_shape[1])
{
throw ngraph_error("mean size is not equal to input channel size");
}
if (variance->get_shape()[0] != m_bn_input_shape[1])
{
throw ngraph_error("variance size is not equal to input channel size");
}
if (variance->get_shape().size() != mean->get_shape().size())
{
throw ngraph_error("mean and variance rank does not match");
}
if (gamma->get_shape().size() != beta->get_shape().size())
{
throw ngraph_error("gamma and beta rank does not match");
}
if (input->get_element_type() != mean->get_element_type())
{
throw ngraph_error("input tensor and mean element type does not match");
}
if (input->get_element_type() != variance->get_element_type())
{
throw ngraph_error("input tensor and variance element type does not match");
}
if (gamma->get_element_type() != beta->get_element_type())
{
throw ngraph_error("gamma and beta element type does not match");
}
}
std::shared_ptr<ngraph::Node> ngraph::op::BatchNorm::copy_with_new_args(
const std::vector<std::shared_ptr<ngraph::Node>>& new_args) const
{
if (new_args.size() != 5)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<BatchNorm>(
m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
}
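For reference, the computation this op is meant to represent (inference-style batch normalization, with mean and variance supplied as inputs rather than computed by the kernel) is, per channel c:

\[
y_{n,c,h,w} = \gamma_c \cdot \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
\]

where epsilon is the double passed to the constructor above and gamma, beta, mean, and variance are the rank-1, per-channel tensors that the checks above validate.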
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
class BatchNorm : public RequiresTensorViewArgs
{
public:
BatchNorm(double eps,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input,
std::shared_ptr<Node> mean,
std::shared_ptr<Node> variance);
const Shape& get_inputs_shape() const { return m_bn_input_shape; }
const Shape& get_variance_shape() const { return m_bn_variance_shape; }
const Shape& get_mean_shape() const { return m_bn_mean_shape; }
double get_eps_value() const { return m_epsilon; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override;
private:
Shape m_bn_input_shape;
Shape m_bn_variance_shape;
Shape m_bn_mean_shape;
double m_epsilon;
};
}
}
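A minimal usage sketch of this class follows (illustrative only: it assumes the ngraph headers are on the include path and uses hypothetical shapes, but the constructor signature and accessors are the ones declared above):

#include <memory>
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/parameter.hpp"

std::shared_ptr<ngraph::op::BatchNorm> make_example_bn()
{
    using namespace ngraph;
    // NCHW input with 2 channels; gamma/beta/mean/variance are per-channel (rank 1)
    auto input = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 2, 2});
    auto gamma = std::make_shared<op::Parameter>(element::f32, Shape{2});
    auto beta = std::make_shared<op::Parameter>(element::f32, Shape{2});
    auto mean = std::make_shared<op::Parameter>(element::f32, Shape{2});
    auto variance = std::make_shared<op::Parameter>(element::f32, Shape{2});
    double eps = 0.001;

    auto bn = std::make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, variance);
    // Accessors declared in the header above:
    //   bn->get_eps_value()    -> 0.001
    //   bn->get_inputs_shape() -> Shape{1, 2, 2, 2}
    //   bn->get_mean_shape()   -> Shape{2}
    return bn;
}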
File mode changed from 100644 to 100755
...@@ -14,16 +14,18 @@
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
...@@ -42,7 +44,6 @@
#include "ngraph/ops/select_and_scatter.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/runtime/cpu/cpu_kernel_emitters.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/types/element_type.hpp"
...@@ -210,6 +211,85 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitMatmulBias)
writer << "}\n";
}
void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitBatchNorm)
{
const ngraph::op::BatchNorm* batchnorm = static_cast<const ngraph::op::BatchNorm*>(node);
// get the shape of all the inputs and output to batchnorm
auto gamma_shape = args[0].get_shape();
auto beta_shape = args[1].get_shape();
auto input_shape = args[2].get_shape();
auto mean_shape = args[3].get_shape();
auto variance_shape = args[4].get_shape();
auto result_shape = out[0].get_shape();
// get input element type
const string& et = get_mkldnn_data_type(args[2].get_element_type().c_type_string());
writer << "{\n";
writer.indent++;
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string() << ">bn_weights(2);\n";
auto weights_shape = Shape{2, input_shape[1]};
// push gamma and beta
writer << "auto gamma = " << args[0].get_name() << ";\n";
writer << "auto beta = " << args[1].get_name() << ";\n";
writer << "memcpy(&bn_weights[0], gamma,"
<< args[1].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[1].get_size() << ", beta, "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
// get the eps value from the bn node
writer << "auto epsilon = " << batchnorm->get_eps_value() << ";\n";
// Bind to CPU engine
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
// create memory descriptors
writer << "memory::desc input_data_desc = memory::desc({" << join(input_shape) << "}, " << et
<< ", memory::format::nchw);\n";
// TODO define weights by stacking gamma and beta values
writer << "memory::desc weights_desc = memory::desc({" << join(weights_shape) << "}, " << et
<< ", memory::format::nc);\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape) << "}, " << et
<< ", memory::format::nchw);\n";
writer << "memory::desc mean_desc = memory::desc({" << join(mean_shape) << "}, " << et
<< ", memory::format::x);\n";
writer << "memory::desc variance_desc = memory::desc({" << join(variance_shape) << "}, " << et
<< ", memory::format::x);\n";
// Define memory for the user data
writer << "memory input_data = memory({input_data_desc, cpu_engine}, " << args[2].get_name()
<< ");\n";
writer << "memory weights = memory({weights_desc, cpu_engine}, bn_weights.data()"
<< ");\n";
writer << "memory mean = memory({mean_desc, cpu_engine}, " << args[3].get_name() << ");\n";
writer << "memory variance = memory({variance_desc, cpu_engine}, " << args[4].get_name()
<< ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name() << ");\n";
// create batchnorm descriptor
writer << "batch_normalization_forward::desc bn_fprop_desc = "
"batch_normalization_forward::desc(forward_training,"
<< "input_data_desc, epsilon, use_global_stats|use_scale_shift);\n";
// bn fprop primitive descriptor
writer << "batch_normalization_forward::primitive_desc bn_fprop_prim_desc = "
"batch_normalization_forward::primitive_desc(bn_fprop_desc, cpu_engine);\n";
// create a batchnorm fprop primitive
writer
<< "batch_normalization_forward bn_fprop = batch_normalization_forward(bn_fprop_prim_desc, "
"primitive::at(input_data),primitive::at(mean), primitive::at(variance),"
<< "primitive::at(weights), result); \n";
// create stream and execute
writer << "stream s = stream(stream::kind::eager);\n"
<< "s.submit({bn_fprop}).wait();\n";
writer.indent--;
writer << "}\n";
}
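To make the string emission above easier to follow, here is a sketch of roughly what the generated block looks like for a float input of shape {1, 2, 2, 2}; the tensor names arg0..arg4 and out0 are illustrative (the real code uses the names returned by get_name()), and it assumes the MKLDNN preamble emitted elsewhere has brought the mkldnn types into scope:

{
    // stacked per-channel gamma followed by beta, as built by the memcpy calls above
    std::vector<float> bn_weights(4);
    auto gamma = arg0;
    auto beta = arg1;
    memcpy(&bn_weights[0], gamma, 2 * sizeof(float));
    memcpy(&bn_weights[0] + 2, beta, 2 * sizeof(float));
    auto epsilon = 0.001;

    engine cpu_engine = engine(engine::cpu, 0);
    memory::desc input_data_desc = memory::desc({1, 2, 2, 2}, memory::data_type::f32, memory::format::nchw);
    memory::desc weights_desc = memory::desc({2, 2}, memory::data_type::f32, memory::format::nc);
    memory::desc result_desc = memory::desc({1, 2, 2, 2}, memory::data_type::f32, memory::format::nchw);
    memory::desc mean_desc = memory::desc({2}, memory::data_type::f32, memory::format::x);
    memory::desc variance_desc = memory::desc({2}, memory::data_type::f32, memory::format::x);

    memory input_data = memory({input_data_desc, cpu_engine}, arg2);
    memory weights = memory({weights_desc, cpu_engine}, bn_weights.data());
    memory mean = memory({mean_desc, cpu_engine}, arg3);
    memory variance = memory({variance_desc, cpu_engine}, arg4);
    memory result = memory({result_desc, cpu_engine}, out0);

    // use_global_stats: mean/variance are taken as inputs; use_scale_shift: apply gamma/beta
    batch_normalization_forward::desc bn_fprop_desc = batch_normalization_forward::desc(
        forward_training, input_data_desc, epsilon, use_global_stats | use_scale_shift);
    batch_normalization_forward::primitive_desc bn_fprop_prim_desc =
        batch_normalization_forward::primitive_desc(bn_fprop_desc, cpu_engine);
    batch_normalization_forward bn_fprop =
        batch_normalization_forward(bn_fprop_prim_desc,
                                    primitive::at(input_data),
                                    primitive::at(mean),
                                    primitive::at(variance),
                                    primitive::at(weights),
                                    result);

    stream s = stream(stream::kind::eager);
    s.submit({bn_fprop}).wait();
}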
void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitDot)
{
const ngraph::op::Dot* dot = static_cast<const ngraph::op::Dot*>(node);
......
...@@ -102,6 +102,7 @@ namespace ngraph
static void EMITTER_DECL(EmitAvgPool);
static void EMITTER_DECL(EmitAvgPoolBackprop);
static void EMITTER_DECL(EmitPad);
static void EMITTER_DECL(EmitBatchNorm);
static void EMITTER_DECL(EmitMaxPoolBackprop);
static void EmitMKLDNNPreamble(codegen::CodeWriter& writer);
......
...@@ -39,6 +39,7 @@
#include "ngraph/ops/asin.hpp"
#include "ngraph/ops/atan.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
...@@ -215,6 +216,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::AvgPool), &runtime::cpu::CPU_Emitter::EmitAvgPool},
{TI(ngraph::op::AvgPoolBackprop), &runtime::cpu::CPU_Emitter::EmitAvgPoolBackprop},
{TI(ngraph::op::Pad), &runtime::cpu::CPU_Emitter::EmitPad},
{TI(ngraph::op::BatchNorm), &runtime::cpu::CPU_Emitter::EmitBatchNorm},
{TI(ngraph::op::MaxPoolBackprop), &runtime::cpu::CPU_Emitter::EmitMaxPoolBackprop},
};
...@@ -244,7 +246,6 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
pass_manager.run_passes(m_function);
codegen::CodeWriter writer;
bool include_mkldnn_headers = false;
...@@ -262,7 +263,6 @@ void runtime::cpu::CPU_ExternalFunction::compile()
writer +=
R"(// Generated by the NGraph CPU backend
#include <cmath>
)";
writer +=
......
...@@ -14,9 +14,9 @@
* limitations under the License.
*******************************************************************************/
#include "cpu_layout_descriptor.hpp"
#include <algorithm>
#include <numeric>
namespace ngraph
{
......
...@@ -20,6 +20,7 @@
#include "ngraph/node.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/max_pool.hpp"
...@@ -41,7 +42,8 @@ namespace ngraph
TI(ngraph::op::Convolution),
TI(ngraph::op::ConvolutionBackpropData),
TI(ngraph::op::ConvolutionBackpropFilters),
TI(ngraph::op::MaxPool),
TI(ngraph::op::BatchNorm)};
bool IsMKLDNNOp(ngraph::Node& op)
{
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "cpu_fusion.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/sqrt.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"

static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
std::shared_ptr<ngraph::Node> arg,
bool& transpose_w,
ngraph::Shape& shape_w)
{
auto r_w = std::dynamic_pointer_cast<ngraph::op::Reshape>(reshape);

if (!r_w)
{
return true; //nth to do; reshape isn't a reshape
}

if (r_w->get_shape().size() != 2)
{
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " doesn't reshape into matrix"
<< ngraph::vector_to_string(r_w->get_shape());
return false;
}

auto io = r_w->get_input_order();
if (r_w->get_shape().size() != arg->get_shape().size()) //reshape
{
ngraph::AxisVector dio(io.size());
std::iota(begin(dio), end(dio), 0);

if (io != dio) //we can't reshape and transpose at the same time
{
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " is not in default order "
<< ngraph::vector_to_string(io);
NGRAPH_DEBUG << "r_w shape = " << ngraph::vector_to_string(r_w->get_shape());
NGRAPH_DEBUG << "arg shape = " << ngraph::vector_to_string(arg->get_shape());
return false;
}

shape_w = r_w->get_shape();
}
else
{
if (io == ngraph::AxisVector{1, 0})
{
transpose_w = true;
}
//otherwise no-op reshape
}

return true;
}

template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}

std::vector<T> output(input.size());

for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}

return output;
}

void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
{
Shape shape_w{2, 4};
Shape shape_x{4, 1};
Shape shape_b{1};
Shape shape_dot{2, 1};

auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);

auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};

auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);

auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b);
auto pbroadcast = std::make_shared<op::Broadcast>(b, shape_dot, AxisSet{0});
auto padd = pdot + pbroadcast;

ngraph::pattern::gr_callback_fn callback = [W, x, b](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_gemm_pattern against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn = nullptr;

auto mpattern = m.match_root();
if (mpattern->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
return nn;
}

auto dot = mpattern->get_input_op(0);
if (dot->get_shape().size() != 2)
{
NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!";
return nn;
}

bool transpose_w = false;
Shape shape_arg0{pattern_map[W]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0))
{
return nn;
}

bool transpose_x = false;
Shape shape_arg1{pattern_map[x]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1))
{
return nn;
}

auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W],
pattern_map[x],
mpattern->get_input_op(1),
shape_arg0,
shape_arg1,
transpose_w,
transpose_x));
return cg;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
this->add_matcher(m);
}
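Semantically, the MatmulBias node substituted by this callback computes

\[
Y = \mathrm{op}(W)\,\mathrm{op}(x) + \mathrm{broadcast}(b)
\]

where op() is an optional transpose absorbed from the surrounding Reshape nodes (the transpose_w / transpose_x flags above), so the whole Dot + Broadcast + Add sub-graph can be lowered to a single cblas_gemm call.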
void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
{
// construct variance
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
auto input_sq = std::make_shared<op::Multiply>(input, input);
auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
auto variance = std::make_shared<op::Divide>(xmu, N);
auto variance_label = std::make_shared<pattern::op::Label>(variance, nullptr, Nodes{variance});
auto variance_with_broadcast =
std::make_shared<op::Broadcast>(variance_label, Shape{2, 3}, AxisSet{0});
// construct mean
auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
auto mean = std::make_shared<op::Divide>(sum_input1, N);
auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, Nodes{mean});
auto mean_with_broadcast = std::make_shared<op::Broadcast>(mean_label, Shape{2, 3}, AxisSet{0});
auto input_diff_mean = std::make_shared<op::Subtract>(input, mean_with_broadcast);
// Eps
auto eps_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
auto eps_with_broadcast = std::make_shared<op::Broadcast>(eps_label, Shape{2, 3}, AxisSet{0});
auto add1 = std::make_shared<op::Add>(eps_with_broadcast, variance_with_broadcast);
auto sqrt_variance_eps = std::make_shared<op::Sqrt>(add1);
auto divide_mean_variance = std::make_shared<op::Divide>(input_diff_mean, sqrt_variance_eps);
//Gamma
auto gamma_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
auto gamma_with_broadcast =
std::make_shared<op::Broadcast>(gamma_label, Shape{2, 3}, AxisSet{0});
auto multiply_gamma =
std::make_shared<op::Multiply>(gamma_with_broadcast, divide_mean_variance);
//Beta
auto beta_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
auto beta_with_broadcast = std::make_shared<op::Broadcast>(beta_label, Shape{2, 3}, AxisSet{0});
auto add_beta = std::make_shared<op::Add>(beta_with_broadcast, multiply_gamma);
// This completes fprop bn pattern
//Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback =
[variance_label, mean_label, input, eps_label, gamma_label, beta_label](
pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against "
<< m.match_root()->get_name();
std::shared_ptr<Node> nn = nullptr;
//TODO - add assert's based on the matched node
auto pattern_map = m.get_pattern_map();
NGRAPH_DEBUG << "Input: " << pattern_map[input]->get_name() << " "
<< pattern_map[input]->get_shape().size();
NGRAPH_DEBUG << "Variance: " << pattern_map[variance_label]->get_name() << " "
<< pattern_map[variance_label]->get_shape().size();
NGRAPH_DEBUG << "Mean: " << pattern_map[mean_label]->get_name() << " "
<< pattern_map[mean_label]->get_shape().size();
NGRAPH_DEBUG << "eps: " << pattern_map[eps_label]->get_name() << " "
<< pattern_map[eps_label]->get_shape().size();
NGRAPH_DEBUG << "gamma: " << pattern_map[gamma_label]->get_name() << " "
<< pattern_map[gamma_label]->get_shape().size();
NGRAPH_DEBUG << "beta: " << pattern_map[beta_label]->get_name() << " "
<< pattern_map[beta_label]->get_shape().size();
// don't fuse if the input doesn't have 4 dims
if (pattern_map[input]->get_shape().size() != 4)
{
NGRAPH_DEBUG << "Input to bn doesnt not have a rank=4, so not fusing";
return nn;
}
Shape bn_output_shape{m.match_root()->get_shape()};
Shape m_bn_mean_shape{pattern_map[mean_label]->get_shape()};
Shape m_bn_variance_shape{pattern_map[variance_label]->get_shape()};
// get epsilon value
auto eps_ptr = std::dynamic_pointer_cast<op::Constant>(pattern_map[eps_label]);
double epsilon = *(reinterpret_cast<const double*>(eps_ptr->get_data_ptr()));
auto bn_node = std::shared_ptr<Node>(new op::BatchNorm(epsilon,
pattern_map[gamma_label],
pattern_map[beta_label],
pattern_map[input],
pattern_map[mean_label],
pattern_map[variance_label]));
return bn_node;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(add_beta, callback);
this->add_matcher(m);
}
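For clarity, the sub-graphs constructed above encode the usual one-pass statistics: with N elements per channel (the constant N above), the pattern matches

\[
\mu = \frac{1}{N}\sum_i x_i,
\qquad
\sigma^2 = \frac{1}{N}\left(\sum_i x_i^2 - \frac{1}{N}\Big(\sum_i x_i\Big)^2\right),
\qquad
y = \gamma\,\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta,
\]

which is exactly the expression the callback collapses into a single op::BatchNorm node.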
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class CPUFusion;
}
}
}
}
class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
CPUFusion()
: GraphRewrite()
{
construct_gemm_pattern();
construct_fprop_bn();
}

private:
void construct_gemm_pattern();
void construct_fprop_bn();
};
...@@ -22,6 +22,7 @@
#include "ngraph/ops/asin.hpp"
#include "ngraph/ops/atan.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
...@@ -367,6 +368,11 @@ static shared_ptr<ngraph::Function>
padding_below,
padding_above);
}
else if (node_op == "BatchNorm")
{
auto epsilon = node_js.at("eps").get<double>();
node = make_shared<op::BatchNorm>(epsilon, args[0], args[1], args[2], args[3], args[4]);
}
else if (node_op == "Broadcast") else if (node_op == "Broadcast")
{ {
auto shape = node_js.at("shape").get<vector<size_t>>(); auto shape = node_js.at("shape").get<vector<size_t>>();
...@@ -840,6 +846,11 @@ static json write(const Node& n) ...@@ -840,6 +846,11 @@ static json write(const Node& n)
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
} }
else if (node_op == "BatchNorm")
{
auto tmp = dynamic_cast<const op::BatchNorm*>(&n);
node["eps"] = tmp->get_eps_value();
}
else if (node_op == "Broadcast") else if (node_op == "Broadcast")
{ {
auto tmp = dynamic_cast<const op::Broadcast*>(&n); auto tmp = dynamic_cast<const op::Broadcast*>(&n);
......
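As a sketch of the round trip these two branches add (assuming the nlohmann-style json object that ngraph/json.hpp provides; the helper names here are hypothetical), only the epsilon value needs to travel alongside the standard op fields, since the five inputs are recorded like any other node's inputs:

#include "ngraph/json.hpp"

// Hypothetical helpers mirroring the new serializer branches above.
nlohmann::json write_batch_norm_fields(double eps)
{
    nlohmann::json node_js;
    node_js["op"] = "BatchNorm";
    node_js["eps"] = eps; // written by the new branch in write()
    return node_js;
}

double read_batch_norm_eps(const nlohmann::json& node_js)
{
    // read back by the new branch in the deserializer
    return node_js.at("eps").get<double>();
}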
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>

#include "gtest/gtest.h"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
//
#include "ngraph/file_util.hpp"
#include "ngraph/json.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

TEST(cpu_fusion, gemm_pattern)
{
Shape shape_w{2, 4};
Shape shape_x{4, 1};
Shape shape_b{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);

auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;

auto W = std::make_shared<pattern::op::Label>(A);
auto x = std::make_shared<pattern::op::Label>(B);

auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};

auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);

auto pdot = make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(C);
auto pbroadcast = make_shared<op::Broadcast>(b, dot->get_shape(), AxisSet{0});
auto padd = pdot + pbroadcast;

TestMatcher n(nullptr);
ASSERT_TRUE(n.match(padd, add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);

auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, W->get_shape());
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, x->get_shape());
auto re_dot = make_shared<op::Dot>(reshape_w, reshape_x);
auto re_add = re_dot + broadcast;
ASSERT_TRUE(n.match(padd, re_add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);

auto cg =
make_shared<op::MatmulBias>(W, x, broadcast, W->get_shape(), x->get_shape(), false, false);
}

TEST(cpu_fusion, gemm_cpu)
{
Shape shapeA{3, 2};
Shape shapeB{2, 3};
Shape shapeC{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeB);

auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, Shape{2, 3});
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, Shape{3, 2});

auto one = op::Constant::create<float>(element::f32, Shape{}, std::vector<float>{1.0f});

auto broadcast = make_shared<op::Broadcast>(one, shapeC, AxisSet{0, 1});
auto cg =
make_shared<op::MatmulBias>(A, B, broadcast, A->get_shape(), B->get_shape(), true, true);

auto f = make_shared<Function>(cg, op::Parameters{A, B});

auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);

shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shapeA);
shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shapeB);
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, shapeC);

vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{3.0f, 3.0f, 3.0f, 9.0f, 9.0f, 9.0f};
copy_data(a, dataA);
copy_data(b, dataB);

cf->call({a, b}, {result});
vector<float> expected{10, 28, 37, 109};
ASSERT_TRUE(read_vector<float>(result) == expected);
}

TEST(cpu_fusion, cpu_fusion_pass_basic)
{
Shape shape{};
Shape shape_w{2, 4};
Shape shape_x{4, 1};
Shape shape_b{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);

auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::Parameters{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
}

TEST(cpu_fusion, gemm_mlp)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/mnist_mlp_forward.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(ccg, 3);
}
//TODO: Move this test to backend_test.in.cpp once we have the INTERPRETER
// implementation for batchnorm
TEST(cpu_fusion, batchnorm_fprop_b1c2h2w2)
{
auto input_shape = Shape{1, 2, 2, 2};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{1, 2, 2, 2};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
auto f = make_shared<Function>(bn, op::Parameters{mean, var, input, gamma, beta});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto _input = backend->make_primary_tensor_view(element::f32, Shape{1, 2, 2, 2});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872,
0.89177299});
auto _mean = backend->make_primary_tensor_view(element::f32, mean_shape);
copy_data(_mean, vector<float>{0.60291237, 0.59972727});
auto _var = backend->make_primary_tensor_view(element::f32, var_shape);
copy_data(_var, vector<float>{0.00472505, 0.03617825});
auto _gamma = backend->make_primary_tensor_view(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->make_primary_tensor_view(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto result = backend->make_primary_tensor_view(element::f32, shape_r);
vector<float> expected_result{-0.71498716,
1.48388731,
-0.00196938,
-0.76693159,
-0.91316032,
0.23943391,
-0.84090298,
1.51462936};
cf->call({_mean, _var, _input, _gamma, _beta}, {result});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
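A quick hand check of the first expected value (channel 0 of the first input element, with gamma = 1 and beta = 0):

\[
\frac{0.54881352 - 0.60291237}{\sqrt{0.00472505 + 0.001}}
= \frac{-0.05409885}{\sqrt{0.00572505}}
\approx -0.7150,
\]

which agrees with expected_result[0] = -0.71498716 to within the tolerance used by test::all_close.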
TEST(cpu_fusion, batchnorm_fprop_b2c2h2w1)
{
auto input_shape = Shape{2, 2, 2, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
auto f = make_shared<Function>(bn, op::Parameters{mean, var, input, gamma, beta});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto _input = backend->make_primary_tensor_view(element::f32, Shape{2, 2, 2, 1});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872,
0.89177299});
auto _mean = backend->make_primary_tensor_view(element::f32, mean_shape);
copy_data(_mean, vector<float>{0.60291237, 0.59972727});
auto _var = backend->make_primary_tensor_view(element::f32, var_shape);
copy_data(_var, vector<float>{0.00472505, 0.03617825});
auto _gamma = backend->make_primary_tensor_view(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->make_primary_tensor_view(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto result = backend->make_primary_tensor_view(element::f32, shape_r);
vector<float> expected_result{
-0.714987, 1.48389, 0.015746, -0.284436, -2.36912, 0.56806, -0.840903, 1.51463};
cf->call({_mean, _var, _input, _gamma, _beta}, {result});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(result)));
}
TEST(cpu_fusion, fuse_fprop_bn)
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_before_fusion.png");
pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_after_fusion.png");
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop_b2c3h2w2.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::BatchNorm>(func);
ASSERT_EQ(ccg, 1);
}
[{
"name" : "Function_4",
"ops" : [
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_155",
"op" : "Parameter",
"outputs" : ["Parameter_155_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_109",
"op" : "Parameter",
"outputs" : ["Parameter_109_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_108",
"op" : "Parameter",
"outputs" : ["Parameter_108_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_107",
"op" : "Parameter",
"outputs" : ["Parameter_107_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_106",
"op" : "Parameter",
"outputs" : ["Parameter_106_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_105",
"op" : "Parameter",
"outputs" : ["Parameter_105_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_137",
"op" : "Constant",
"outputs" : ["Constant_137_0"],
"shape" : [],
"value" : ["0.001"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_117",
"op" : "Constant",
"outputs" : ["Constant_117_0"],
"shape" : [3],
"value" : [ "8", "8", "8" ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_111",
"op" : "Constant",
"outputs" : ["Constant_111_0"],
"shape" : [3],
"value" : [ "8", "8", "8" ]
},
{
"input_order" : [0],
"inputs" : ["Parameter_107"],
"name" : "Reshape_148",
"op" : "Reshape",
"output_shape" : [ 1, 3, 1, 1 ],
"outputs" : ["Reshape_148_0"]
},
{
"input_order" : [0],
"inputs" : ["Parameter_106"],
"name" : "Reshape_147",
"op" : "Reshape",
"output_shape" : [ 1, 3, 1, 1 ],
"outputs" : ["Reshape_147_0"]
},
{
"inputs" : [ "Parameter_105", "Parameter_105" ],
"name" : "Multiply_115",
"op" : "Multiply",
"outputs" : ["Multiply_115_0"]
},
{
"inputs" : ["Parameter_105"],
"name" : "Sum_110",
"op" : "Sum",
"outputs" : ["Sum_110_0"],
"reduction_axes" : [ 0, 2, 3 ]
},
{
"inputs" : ["Parameter_105"],
"name" : "Sum_114",
"op" : "Sum",
"outputs" : ["Sum_114_0"],
"reduction_axes" : [ 0, 2, 3 ]
},
{
"axes" : [ 0, 1, 2, 3 ],
"inputs" : ["Constant_137"],
"name" : "Broadcast_138",
"op" : "Broadcast",
"outputs" : ["Broadcast_138_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"input_order" : [ 0, 1, 2, 3 ],
"inputs" : ["Reshape_148"],
"name" : "Reshape_152",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_152_0"]
},
{
"input_order" : [ 0, 1, 2, 3 ],
"inputs" : ["Reshape_147"],
"name" : "Reshape_149",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_149_0"]
},
{
"inputs" : ["Multiply_115"],
"name" : "Sum_116",
"op" : "Sum",
"outputs" : ["Sum_116_0"],
"reduction_axes" : [ 0, 2, 3 ]
},
{
"inputs" : [ "Sum_110", "Constant_111" ],
"name" : "Divide_112",
"op" : "Divide",
"outputs" : ["Divide_112_0"]
},
{
"inputs" : [ "Sum_114", "Sum_114" ],
"name" : "Multiply_118",
"op" : "Multiply",
"outputs" : ["Multiply_118_0"]
},
{
"axes" : [ 0, 2, 3 ],
"inputs" : ["Reshape_152"],
"name" : "Broadcast_153",
"op" : "Broadcast",
"outputs" : ["Broadcast_153_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"axes" : [ 0, 2, 3 ],
"inputs" : ["Reshape_149"],
"name" : "Broadcast_150",
"op" : "Broadcast",
"outputs" : ["Broadcast_150_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"input_order" : [0],
"inputs" : ["Divide_112"],
"name" : "Reshape_113",
"op" : "Reshape",
"output_shape" : [ 1, 3, 1, 1 ],
"outputs" : ["Reshape_113_0"]
},
{
"inputs" : [ "Multiply_118", "Constant_117" ],
"name" : "Divide_119",
"op" : "Divide",
"outputs" : ["Divide_119_0"]
},
{
"input_order" : [ 0, 1, 2, 3 ],
"inputs" : ["Reshape_113"],
"name" : "Reshape_143",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_143_0"]
},
{
"inputs" : [ "Sum_116", "Divide_119" ],
"name" : "Subtract_120",
"op" : "Subtract",
"outputs" : ["Subtract_120_0"]
},
{
"axes" : [ 0, 2, 3 ],
"inputs" : ["Reshape_143"],
"name" : "Broadcast_144",
"op" : "Broadcast",
"outputs" : ["Broadcast_144_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"inputs" : [ "Subtract_120", "Constant_117" ],
"name" : "Divide_121",
"op" : "Divide",
"outputs" : ["Divide_121_0"]
},
{
"inputs" : [ "Parameter_105", "Broadcast_144" ],
"name" : "Subtract_145",
"op" : "Subtract",
"outputs" : ["Subtract_145_0"]
},
{
"input_order" : [0],
"inputs" : ["Divide_121"],
"name" : "Reshape_122",
"op" : "Reshape",
"output_shape" : [ 1, 3, 1, 1 ],
"outputs" : ["Reshape_122_0"]
},
{
"input_order" : [ 0, 1, 2, 3 ],
"inputs" : ["Reshape_122"],
"name" : "Reshape_139",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_139_0"]
},
{
"axes" : [ 0, 2, 3 ],
"inputs" : ["Reshape_139"],
"name" : "Broadcast_140",
"op" : "Broadcast",
"outputs" : ["Broadcast_140_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"inputs" : [ "Broadcast_140", "Broadcast_138" ],
"name" : "Add_141",
"op" : "Add",
"outputs" : ["Add_141_0"]
},
{
"inputs" : ["Add_141"],
"name" : "Sqrt_142",
"op" : "Sqrt",
"outputs" : ["Sqrt_142_0"]
},
{
"inputs" : [ "Subtract_145", "Sqrt_142" ],
"name" : "Divide_146",
"op" : "Divide",
"outputs" : ["Divide_146_0"]
},
{
"inputs" : [ "Divide_146", "Broadcast_150" ],
"name" : "Multiply_151",
"op" : "Multiply",
"outputs" : ["Multiply_151_0"]
},
{
"inputs" : [ "Multiply_151", "Broadcast_153" ],
"name" : "Add_154",
"op" : "Add",
"outputs" : ["Add_154_0"]
},
{
"inputs" : [ "Add_154", "Parameter_155" ],
"name" : "Multiply_156",
"op" : "Multiply",
"outputs" : ["Multiply_156_0"]
}
],
"parameters" : [
"Parameter_105", "Parameter_106", "Parameter_107", "Parameter_108",
"Parameter_109", "Parameter_155"
],
"result" : ["Multiply_156"]
}]
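Note on the constants in this graph: Constant_111 and Constant_117 hold the per-channel element count used to average the sums, since the reductions run over axes {0, 2, 3} of the [2, 3, 2, 2] input:

\[
N = mb \times h \times w = 2 \times 2 \times 2 = 8.
\]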
...@@ -21,15 +21,26 @@
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/sqrt.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/ops/sum.hpp" #include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp" #include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp" #include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "util/matcher.hpp" #include "util/matcher.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -125,6 +136,34 @@ std::shared_ptr<pattern::op::Label> construct_sum_pattern() //for the sake of ex ...@@ -125,6 +136,34 @@ std::shared_ptr<pattern::op::Label> construct_sum_pattern() //for the sake of ex
return std::make_shared<pattern::op::Label>(element::i32, Shape{}, sum_predicate); return std::make_shared<pattern::op::Label>(element::i32, Shape{}, sum_predicate);
} }
static std::shared_ptr<pattern::op::Label> construct_variance_graph()
{
// construct variance
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
auto input_sq = std::make_shared<op::Multiply>(input, input);
auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
auto variance = std::make_shared<op::Divide>(xmu, N);
auto variance_label = std::make_shared<pattern::op::Label>(variance, nullptr, Nodes{variance});
return variance_label;
}
static std::shared_ptr<pattern::op::Label> construct_mean_graph()
{
// construct mean
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
auto mean = std::make_shared<op::Divide>(sum_input1, N);
auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, Nodes{mean});
return mean_label;
}
class TestGraphRewrite : public ngraph::pass::GraphRewrite
{
public:
...@@ -474,3 +513,37 @@ TEST(pattern, sum)
ASSERT_TRUE(n.match(nested_reduce_label, nested_sum_graph));
ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);
}
TEST(pattern, mean)
{
//construct mean
TestMatcher n(nullptr);
auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
auto mean = std::make_shared<op::Divide>(sum_input1, N);
auto mean_graph = construct_mean_graph();
ASSERT_TRUE(n.match(mean_graph, mean));
ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
}
TEST(pattern, variance)
{
//construct variance
TestMatcher n(nullptr);
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
auto input_sq = std::make_shared<op::Multiply>(input, input);
auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
auto variance = std::make_shared<op::Divide>(xmu, N);
auto var_graph = construct_variance_graph();
ASSERT_TRUE(n.match(var_graph, variance));
ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
}
\ No newline at end of file