Unverified Commit 34b1322d authored by Pruthvi, committed by GitHub

pattern matcher for BatchnormFprop + mkldnn integration in the CPU emitter (#468)

* fuse dot(a,b) + c

cblas_gemm working on mlp

rebase & small fixes

enable debug output

support replacing function's outputs

* WIP pattern matching for variance

* - Added pattern matcher graph to look up the variance sub-graph in bn
- Added test case to verify the variance graph pattern

* added batch norm mean pattern matcher.

* remove reshapes

(cherry picked from commit ecad321fb1b1bc3f7facda229beb940118ca0701)

* fixed mean test to use Matcher.

* resolve merge conflict in test/pattern.cpp

* WIP bn fprop pattern

* fprop bn fusion working

* - Added unit test case to read the serialized bn *.json file and run the bn fprop fusion pass
- Added batchnorm header file and defined the bn class to emit the mkldnn kernel
- Added pattern matcher for fprop bn in CPU graph_rewrite pass

* WIP MKLDNN fprop bn emitter code

* completed fprop batchnorm kernel in CPU emitter

* fixed bug in the emitter code for fprop bn

* - Fixed compilation issues
- unit tests are passing for bn emitter fprop code

* Added support to compute fprop bn with mean and variance as input

* resolved compilation issues

* refactored bn fprop code

* - added batchnorm src file to the CMakeLists
- moved bn fusion under CPU runtime/pass/cpu_fusion
- fixed compilation issue

* Resolved compilation issues in bn emitted code

* Added debug statements in fprop bn emitted code

* added batchnorm.cpp src file

* - Added test case to test fprop batchnorm with known tensor values
- fixed bug related to defining weights in fprop bn

* - Added test case for fprop batchnorm Op
- Added test case for mean and variance pattern matcher
- Added fprop bn *.json file with a 4-dim input (mb2c3h2w2)
- refactored fprop bn op class

* Style fix

* - Removed Debug symbols

* - Fixed header template with correct year
- appended mkldnn.hpp in the CPU generated code

*  Addressed PR review comments
 -  added support for batchnorm op in serializer and de-serializer
 - added more sanity checks to the bn constructor
 - renamed "BatchnormFprop" -> BatchNorm

* - Addressed PR review comments
- replaced auto with the specific mkldnn::type in the emitted bn kernel
- modified function signature to take 'eps' as double instead of <Node> type

* added missing header files, resolved compilation issue

* style fix

* Addressed PR comments
1. initialized member variables for bn in the same order as they are defined
2. renamed bn member variables to start with m_* as per coding convention
3. moved bn fusion test to test/cpu_fusion.cpp
4. style fix
5. added more checks to evaluate type and shape of inputs to bn

* Added support for EMITDECL macro for batchnorm

* - corrected the batchnorm src file name (batchnorm -> batch_norm) as per coding guidelines
- corrected bn copy_with_new_args() method

* Removed redundant SqrtOp support in serializer
parent 5c8f9222
......@@ -36,6 +36,7 @@ set (SRC
ops/abs.cpp
ops/add.cpp
ops/avg_pool.cpp
ops/batch_norm.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_comparison.cpp
ops/binary_elementwise.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/constant.hpp"
ngraph::op::BatchNorm::BatchNorm(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance)
: RequiresTensorViewArgs("BatchNorm", {gamma, beta, input, mean, variance})
, m_bn_input_shape(input->get_shape())
, m_bn_variance_shape(variance->get_shape())
, m_bn_mean_shape(mean->get_shape())
, m_epsilon(eps)
{
add_output(input->get_element_type(), m_bn_input_shape);
if (m_bn_input_shape.size() < 2)
{
throw ngraph_error("input tensor to batchnorm much have tensor of atleast rank 2");
}
if (m_bn_input_shape[1] == 0)
{
throw ngraph_error(
"input tensor must have at least one channel axis for batch normalization");
}
if ((m_bn_mean_shape.size() != 1) && (m_bn_variance_shape.size() != 1) &&
(gamma->get_shape().size() != 1) && (beta->get_shape().size() != 1))
{
throw ngraph_error("gamma, beta, mean, variance shoud have all rank 1");
}
// assuming input shape (N, C, H, W), check if the size of mean and
// variance are equal to channel axis
if (mean->get_shape()[0] != m_bn_input_shape[1])
{
throw ngraph_error("mean size is not equal to input channel size");
}
if (variance->get_shape()[0] != m_bn_input_shape[1])
{
throw ngraph_error("variance size is not equal to input channel size");
}
if (variance->get_shape().size() != mean->get_shape().size())
{
throw ngraph_error("mean and variance rank does not match");
}
if (gamma->get_shape().size() != beta->get_shape().size())
{
throw ngraph_error("gamma and beta rank does not match");
}
if (input->get_element_type() != mean->get_element_type())
{
throw ngraph_error("input tensor and mean element type does not match");
}
if (input->get_element_type() != variance->get_element_type())
{
throw ngraph_error("input tensor and variance element type does not match");
}
if (gamma->get_element_type() != beta->get_element_type())
{
throw ngraph_error("gamma and beta element type does not match");
}
}
std::shared_ptr<ngraph::Node> ngraph::op::BatchNorm::copy_with_new_args(
const std::vector<std::shared_ptr<ngraph::Node>>& new_args) const
{
if (new_args.size() != 5)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<BatchNorm>(
m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
}
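
For reference, the shape checks above (rank-1 gamma/beta/mean/variance, each sized to the channel axis of an NCHW input) assume the standard batch-normalization forward computation, applied per channel c; this is also the computation that the fused sub-graph in the serialized .json test file below spells out node by node:

```latex
y_{n,c,h,w} = \gamma_c \cdot \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
```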
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
class BatchNorm : public RequiresTensorViewArgs
{
public:
BatchNorm(double eps,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input,
std::shared_ptr<Node> mean,
std::shared_ptr<Node> variance);
const Shape& get_inputs_shape() const { return m_bn_input_shape; }
const Shape& get_variance_shape() const { return m_bn_variance_shape; }
const Shape& get_mean_shape() const { return m_bn_mean_shape; }
double get_eps_value() const { return m_epsilon; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override;
private:
Shape m_bn_input_shape;
Shape m_bn_variance_shape;
Shape m_bn_mean_shape;
double m_epsilon;
};
}
}
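
A minimal usage sketch (not part of this commit; the shapes and the free function are illustrative) showing how the new op is constructed from five nodes that satisfy the constructor checks above:

```cpp
// Hypothetical example: build a BatchNorm node over an NCHW input with C = 3 channels.
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/batch_norm.hpp"

using namespace ngraph;

std::shared_ptr<Node> make_bn_node()
{
    auto input    = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 2, 2});
    auto gamma    = std::make_shared<op::Parameter>(element::f32, Shape{3});
    auto beta     = std::make_shared<op::Parameter>(element::f32, Shape{3});
    auto mean     = std::make_shared<op::Parameter>(element::f32, Shape{3});
    auto variance = std::make_shared<op::Parameter>(element::f32, Shape{3});
    double eps = 0.001; // epsilon is a plain double, per the PR review comments
    return std::make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, variance);
}
```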
File mode changed from 100644 to 100755
......@@ -14,16 +14,18 @@
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
......@@ -42,7 +44,6 @@
#include "ngraph/ops/select_and_scatter.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_kernel_emitters.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/types/element_type.hpp"
......@@ -210,6 +211,85 @@ void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitMatmulBias)
writer << "}\n";
}
void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitBatchNorm)
{
const ngraph::op::BatchNorm* batchnorm = static_cast<const ngraph::op::BatchNorm*>(node);
// get the shape of all the inputs and output to batchnorm
auto gamma_shape = args[0].get_shape();
auto beta_shape = args[1].get_shape();
auto input_shape = args[2].get_shape();
auto mean_shape = args[3].get_shape();
auto variance_shape = args[4].get_shape();
auto result_shape = out[0].get_shape();
// get input element type
const string& et = get_mkldnn_data_type(args[2].get_element_type().c_type_string());
writer << "{\n";
writer.indent++;
// define weights: gamma followed by beta are stacked into a single 2 x C buffer,
// the layout MKLDNN expects for use_scale_shift
writer << "std::vector<" << args[0].get_element_type().c_type_string() << "> bn_weights(2 * "
<< input_shape[1] << ");\n";
auto weights_shape = Shape{2, input_shape[1]};
// push gamma and beta
writer << "auto gamma = " << args[0].get_name() << ";\n";
writer << "auto beta = " << args[1].get_name() << ";\n";
writer << "memcpy(&bn_weights[0], gamma, "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0] + " << args[0].get_size() << ", beta, "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
// get the eps value from the bn node
writer << "auto epsilon = " << batchnorm->get_eps_value() << ";\n";
// Bind to CPU engine
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
// create memory descriptors
writer << "memory::desc input_data_desc = memory::desc({" << join(input_shape) << "}, " << et
<< ", memory::format::nchw);\n";
// TODO define weights by stacking gamma and beta values
writer << "memory::desc weights_desc = memory::desc({" << join(weights_shape) << "}, " << et
<< ", memory::format::nc);\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape) << "}, " << et
<< ", memory::format::nchw);\n";
writer << "memory::desc mean_desc = memory::desc({" << join(mean_shape) << "}, " << et
<< ", memory::format::x);\n";
writer << "memory::desc variance_desc = memory::desc({" << join(variance_shape) << "}, " << et
<< ", memory::format::x);\n";
// Define memory for the user data
writer << "memory input_data = memory({input_data_desc, cpu_engine}, " << args[2].get_name()
<< ");\n";
writer << "memory weights = memory({weights_desc, cpu_engine}, bn_weights.data()"
<< ");\n";
writer << "memory mean = memory({mean_desc, cpu_engine}, " << args[3].get_name() << ");\n";
writer << "memory variance = memory({variance_desc, cpu_engine}, " << args[4].get_name()
<< ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name() << ");\n";
// create batchnorm descriptor
writer << "batch_normalization_forward::desc bn_fprop_desc = "
"batch_normalization_forward::desc(forward_training,"
<< "input_data_desc, epsilon, use_global_stats|use_scale_shift);\n";
// bn fprop primitive descriptor
writer << "batch_normalization_forward::primitive_desc bn_fprop_prim_desc = "
"batch_normalization_forward::primitive_desc(bn_fprop_desc, cpu_engine);\n";
// create a batchnorm fprop primitive
writer
<< "batch_normalization_forward bn_fprop = batch_normalization_forward(bn_fprop_prim_desc, "
"primitive::at(input_data),primitive::at(mean), primitive::at(variance),"
<< "primitive::at(weights), result); \n";
// create stream and execute
writer << "stream s = stream(stream::kind::eager);\n"
<< "s.submit({bn_fprop}).wait();\n";
writer.indent--;
writer << "}\n";
}
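
For orientation, below is a rough reconstruction of the block the emitter above writes for a float input of shape {2, 3, 2, 2}, assembled from the writer strings. The tensor names arg0..arg4 and out0 are illustrative placeholders for the generated argument names, and the MKLDNN symbols are assumed to be in scope via the mkldnn.hpp preamble appended to the generated code.

```cpp
{
    std::vector<float> bn_weights(2 * 3);
    auto gamma = arg0;
    auto beta = arg1;
    // stack gamma followed by beta into the 2 x C weights buffer
    memcpy(&bn_weights[0], gamma, 3 * sizeof(float));
    memcpy(&bn_weights[0] + 3, beta, 3 * sizeof(float));
    auto epsilon = 0.001;

    engine cpu_engine = engine(engine::cpu, 0);
    memory::desc input_data_desc = memory::desc({2, 3, 2, 2}, memory::data_type::f32, memory::format::nchw);
    memory::desc weights_desc = memory::desc({2, 3}, memory::data_type::f32, memory::format::nc);
    memory::desc result_desc = memory::desc({2, 3, 2, 2}, memory::data_type::f32, memory::format::nchw);
    memory::desc mean_desc = memory::desc({3}, memory::data_type::f32, memory::format::x);
    memory::desc variance_desc = memory::desc({3}, memory::data_type::f32, memory::format::x);

    memory input_data = memory({input_data_desc, cpu_engine}, arg2);
    memory weights = memory({weights_desc, cpu_engine}, bn_weights.data());
    memory mean = memory({mean_desc, cpu_engine}, arg3);
    memory variance = memory({variance_desc, cpu_engine}, arg4);
    memory result = memory({result_desc, cpu_engine}, out0);

    batch_normalization_forward::desc bn_fprop_desc =
        batch_normalization_forward::desc(forward_training,
            input_data_desc, epsilon, use_global_stats | use_scale_shift);
    batch_normalization_forward::primitive_desc bn_fprop_prim_desc =
        batch_normalization_forward::primitive_desc(bn_fprop_desc, cpu_engine);
    batch_normalization_forward bn_fprop = batch_normalization_forward(bn_fprop_prim_desc,
        primitive::at(input_data), primitive::at(mean), primitive::at(variance),
        primitive::at(weights), result);
    stream s = stream(stream::kind::eager);
    s.submit({bn_fprop}).wait();
}
```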
void runtime::cpu::CPU_Emitter::EMITTER_DECL(EmitDot)
{
const ngraph::op::Dot* dot = static_cast<const ngraph::op::Dot*>(node);
......
......@@ -102,6 +102,7 @@ namespace ngraph
static void EMITTER_DECL(EmitAvgPool);
static void EMITTER_DECL(EmitAvgPoolBackprop);
static void EMITTER_DECL(EmitPad);
static void EMITTER_DECL(EmitBatchNorm);
static void EMITTER_DECL(EmitMaxPoolBackprop);
static void EmitMKLDNNPreamble(codegen::CodeWriter& writer);
......
......@@ -39,6 +39,7 @@
#include "ngraph/ops/asin.hpp"
#include "ngraph/ops/atan.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
......@@ -215,6 +216,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::AvgPool), &runtime::cpu::CPU_Emitter::EmitAvgPool},
{TI(ngraph::op::AvgPoolBackprop), &runtime::cpu::CPU_Emitter::EmitAvgPoolBackprop},
{TI(ngraph::op::Pad), &runtime::cpu::CPU_Emitter::EmitPad},
{TI(ngraph::op::BatchNorm), &runtime::cpu::CPU_Emitter::EmitBatchNorm},
{TI(ngraph::op::MaxPoolBackprop), &runtime::cpu::CPU_Emitter::EmitMaxPoolBackprop},
};
......@@ -244,7 +246,6 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
pass_manager.run_passes(m_function);
codegen::CodeWriter writer;
bool include_mkldnn_headers = false;
......@@ -262,7 +263,6 @@ void runtime::cpu::CPU_ExternalFunction::compile()
writer +=
R"(// Generated by the NGraph CPU backend
#include <cmath>
)";
writer +=
......
......@@ -14,9 +14,9 @@
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include "cpu_layout_descriptor.hpp"
#include <algorithm>
#include <numeric>
namespace ngraph
{
......
......@@ -20,6 +20,7 @@
#include "ngraph/node.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/max_pool.hpp"
......@@ -41,7 +42,8 @@ namespace ngraph
TI(ngraph::op::Convolution),
TI(ngraph::op::ConvolutionBackpropData),
TI(ngraph::op::ConvolutionBackpropFilters),
TI(ngraph::op::MaxPool)};
TI(ngraph::op::MaxPool),
TI(ngraph::op::BatchNorm)};
bool IsMKLDNNOp(ngraph::Node& op)
{
......
This diff is collapsed.
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class CPUFusion;
}
}
}
}
class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
CPUFusion()
: GraphRewrite()
{
construct_gemm_pattern();
}
private:
void construct_gemm_pattern();
};
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class CPUFusion;
}
}
}
}
class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
CPUFusion()
: GraphRewrite()
{
construct_gemm_pattern();
construct_fprop_bn();
}
private:
void construct_gemm_pattern();
void construct_fprop_bn();
};
......@@ -22,6 +22,7 @@
#include "ngraph/ops/asin.hpp"
#include "ngraph/ops/atan.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
......@@ -367,6 +368,11 @@ static shared_ptr<ngraph::Function>
padding_below,
padding_above);
}
else if (node_op == "BatchNorm")
{
auto epsilon = node_js.at("eps").get<double>();
node = make_shared<op::BatchNorm>(epsilon, args[0], args[1], args[2], args[3], args[4]);
}
else if (node_op == "Broadcast")
{
auto shape = node_js.at("shape").get<vector<size_t>>();
......@@ -840,6 +846,11 @@ static json write(const Node& n)
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
}
else if (node_op == "BatchNorm")
{
auto tmp = dynamic_cast<const op::BatchNorm*>(&n);
node["eps"] = tmp->get_eps_value();
}
else if (node_op == "Broadcast")
{
auto tmp = dynamic_cast<const op::Broadcast*>(&n);
......
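
With the serializer changes above, a BatchNorm node round-trips through JSON like any other op, with its epsilon stored under "eps" and its five inputs listed in constructor order (gamma, beta, input, mean, variance). A hypothetical serialized entry (node names are illustrative) would look roughly like:

```json
{
  "eps" : 0.001,
  "inputs" : [ "Gamma_1", "Beta_2", "Input_3", "Mean_4", "Variance_5" ],
  "name" : "BatchNorm_6",
  "op" : "BatchNorm",
  "outputs" : ["BatchNorm_6_0"]
}
```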
This diff is collapsed.
[{
"name" : "Function_4",
"ops" : [
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_155",
"op" : "Parameter",
"outputs" : ["Parameter_155_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_109",
"op" : "Parameter",
"outputs" : ["Parameter_109_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_108",
"op" : "Parameter",
"outputs" : ["Parameter_108_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_107",
"op" : "Parameter",
"outputs" : ["Parameter_107_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_106",
"op" : "Parameter",
"outputs" : ["Parameter_106_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_105",
"op" : "Parameter",
"outputs" : ["Parameter_105_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_137",
"op" : "Constant",
"outputs" : ["Constant_137_0"],
"shape" : [],
"value" : ["0.001"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_117",
"op" : "Constant",
"outputs" : ["Constant_117_0"],
"shape" : [3],
"value" : [ "8", "8", "8" ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_111",
"op" : "Constant",
"outputs" : ["Constant_111_0"],
"shape" : [3],
"value" : [ "8", "8", "8" ]
},
{
"input_order" : [0],
"inputs" : ["Parameter_107"],
"name" : "Reshape_148",
"op" : "Reshape",
"output_shape" : [ 1, 3, 1, 1 ],
"outputs" : ["Reshape_148_0"]
},
{
"input_order" : [0],
"inputs" : ["Parameter_106"],
"name" : "Reshape_147",
"op" : "Reshape",
"output_shape" : [ 1, 3, 1, 1 ],
"outputs" : ["Reshape_147_0"]
},
{
"inputs" : [ "Parameter_105", "Parameter_105" ],
"name" : "Multiply_115",
"op" : "Multiply",
"outputs" : ["Multiply_115_0"]
},
{
"inputs" : ["Parameter_105"],
"name" : "Sum_110",
"op" : "Sum",
"outputs" : ["Sum_110_0"],
"reduction_axes" : [ 0, 2, 3 ]
},
{
"inputs" : ["Parameter_105"],
"name" : "Sum_114",
"op" : "Sum",
"outputs" : ["Sum_114_0"],
"reduction_axes" : [ 0, 2, 3 ]
},
{
"axes" : [ 0, 1, 2, 3 ],
"inputs" : ["Constant_137"],
"name" : "Broadcast_138",
"op" : "Broadcast",
"outputs" : ["Broadcast_138_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"input_order" : [ 0, 1, 2, 3 ],
"inputs" : ["Reshape_148"],
"name" : "Reshape_152",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_152_0"]
},
{
"input_order" : [ 0, 1, 2, 3 ],
"inputs" : ["Reshape_147"],
"name" : "Reshape_149",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_149_0"]
},
{
"inputs" : ["Multiply_115"],
"name" : "Sum_116",
"op" : "Sum",
"outputs" : ["Sum_116_0"],
"reduction_axes" : [ 0, 2, 3 ]
},
{
"inputs" : [ "Sum_110", "Constant_111" ],
"name" : "Divide_112",
"op" : "Divide",
"outputs" : ["Divide_112_0"]
},
{
"inputs" : [ "Sum_114", "Sum_114" ],
"name" : "Multiply_118",
"op" : "Multiply",
"outputs" : ["Multiply_118_0"]
},
{
"axes" : [ 0, 2, 3 ],
"inputs" : ["Reshape_152"],
"name" : "Broadcast_153",
"op" : "Broadcast",
"outputs" : ["Broadcast_153_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"axes" : [ 0, 2, 3 ],
"inputs" : ["Reshape_149"],
"name" : "Broadcast_150",
"op" : "Broadcast",
"outputs" : ["Broadcast_150_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"input_order" : [0],
"inputs" : ["Divide_112"],
"name" : "Reshape_113",
"op" : "Reshape",
"output_shape" : [ 1, 3, 1, 1 ],
"outputs" : ["Reshape_113_0"]
},
{
"inputs" : [ "Multiply_118", "Constant_117" ],
"name" : "Divide_119",
"op" : "Divide",
"outputs" : ["Divide_119_0"]
},
{
"input_order" : [ 0, 1, 2, 3 ],
"inputs" : ["Reshape_113"],
"name" : "Reshape_143",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_143_0"]
},
{
"inputs" : [ "Sum_116", "Divide_119" ],
"name" : "Subtract_120",
"op" : "Subtract",
"outputs" : ["Subtract_120_0"]
},
{
"axes" : [ 0, 2, 3 ],
"inputs" : ["Reshape_143"],
"name" : "Broadcast_144",
"op" : "Broadcast",
"outputs" : ["Broadcast_144_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"inputs" : [ "Subtract_120", "Constant_117" ],
"name" : "Divide_121",
"op" : "Divide",
"outputs" : ["Divide_121_0"]
},
{
"inputs" : [ "Parameter_105", "Broadcast_144" ],
"name" : "Subtract_145",
"op" : "Subtract",
"outputs" : ["Subtract_145_0"]
},
{
"input_order" : [0],
"inputs" : ["Divide_121"],
"name" : "Reshape_122",
"op" : "Reshape",
"output_shape" : [ 1, 3, 1, 1 ],
"outputs" : ["Reshape_122_0"]
},
{
"input_order" : [ 0, 1, 2, 3 ],
"inputs" : ["Reshape_122"],
"name" : "Reshape_139",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_139_0"]
},
{
"axes" : [ 0, 2, 3 ],
"inputs" : ["Reshape_139"],
"name" : "Broadcast_140",
"op" : "Broadcast",
"outputs" : ["Broadcast_140_0"],
"shape" : [ 2, 3, 2, 2 ]
},
{
"inputs" : [ "Broadcast_140", "Broadcast_138" ],
"name" : "Add_141",
"op" : "Add",
"outputs" : ["Add_141_0"]
},
{
"inputs" : ["Add_141"],
"name" : "Sqrt_142",
"op" : "Sqrt",
"outputs" : ["Sqrt_142_0"]
},
{
"inputs" : [ "Subtract_145", "Sqrt_142" ],
"name" : "Divide_146",
"op" : "Divide",
"outputs" : ["Divide_146_0"]
},
{
"inputs" : [ "Divide_146", "Broadcast_150" ],
"name" : "Multiply_151",
"op" : "Multiply",
"outputs" : ["Multiply_151_0"]
},
{
"inputs" : [ "Multiply_151", "Broadcast_153" ],
"name" : "Add_154",
"op" : "Add",
"outputs" : ["Add_154_0"]
},
{
"inputs" : [ "Add_154", "Parameter_155" ],
"name" : "Multiply_156",
"op" : "Multiply",
"outputs" : ["Multiply_156_0"]
}
],
"parameters" : [
"Parameter_105", "Parameter_106", "Parameter_107", "Parameter_108",
"Parameter_109", "Parameter_155"
],
"result" : ["Multiply_156"]
}]
......@@ -21,15 +21,26 @@
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/sqrt.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "util/matcher.hpp"
using namespace ngraph;
......@@ -125,6 +136,34 @@ std::shared_ptr<pattern::op::Label> construct_sum_pattern() //for the sake of ex
return std::make_shared<pattern::op::Label>(element::i32, Shape{}, sum_predicate);
}
static std::shared_ptr<pattern::op::Label> construct_variance_graph()
{
// construct variance
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
auto input_sq = std::make_shared<op::Multiply>(input, input);
auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
auto variance = std::make_shared<op::Divide>(xmu, N);
auto variance_label = std::make_shared<pattern::op::Label>(variance, nullptr, Nodes{variance});
return variance_label;
}
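
The helper above encodes variance via the shortcut formula (sum of squares minus squared sum), which is what the bn training graph emits instead of an explicit (x - mean)² sub-graph:

```latex
\operatorname{Var}(x) = \frac{1}{N}\left(\sum_i x_i^2 - \frac{\left(\sum_i x_i\right)^2}{N}\right) = E[x^2] - (E[x])^2
```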
static std::shared_ptr<pattern::op::Label> construct_mean_graph()
{
//construct mean;
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
auto mean = std::make_shared<op::Divide>(sum_input1, N);
auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, Nodes{mean});
return mean_label;
}
class TestGraphRewrite : public ngraph::pass::GraphRewrite
{
public:
......@@ -474,3 +513,37 @@ TEST(pattern, sum)
ASSERT_TRUE(n.match(nested_reduce_label, nested_sum_graph));
ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);
}
TEST(pattern, mean)
{
//construct mean
TestMatcher n(nullptr);
auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
auto mean = std::make_shared<op::Divide>(sum_input1, N);
auto mean_graph = construct_mean_graph();
ASSERT_TRUE(n.match(mean_graph, mean));
ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
}
TEST(pattern, variance)
{
//construct variance
TestMatcher n(nullptr);
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
auto input_sq = std::make_shared<op::Multiply>(input, input);
auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
auto variance = std::make_shared<op::Divide>(xmu, N);
auto var_graph = construct_variance_graph();
ASSERT_TRUE(n.match(var_graph, variance));
ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
}
\ No newline at end of file