Unverified commit 0352e218, authored by Fenglei, committed by GitHub

Merge branch 'master' into tfl/gpu_framework_codegen

parents 8986a83e 29014bab
@@ -15,10 +15,12 @@
 *******************************************************************************/
 #include "ngraph/builder/reduce_ops.hpp"
+#include "ngraph/builder/autobroadcast.hpp"
 #include "ngraph/ops/add.hpp"
 #include "ngraph/ops/divide.hpp"
 #include "ngraph/ops/multiply.hpp"
 #include "ngraph/ops/power.hpp"
+#include "ngraph/ops/reshape.hpp"
 #include "ngraph/ops/subtract.hpp"
 #include "ngraph/ops/sum.hpp"
@@ -80,27 +82,34 @@ namespace ngraph
                 const AxisSet& reduction_axes,
                 const bool bessel_correction)
             {
-                auto xsum = std::make_shared<op::Sum>(node, reduction_axes);
-                auto x2 = node * node;
-                auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes);
+                std::shared_ptr<Node> mu = mean(node, reduction_axes);
+
+                auto reshape = node->get_shape();
+                for (auto i : reduction_axes)
+                {
+                    reshape[i] = 1;
+                }
+
+                ngraph::AxisVector order(mu->get_shape().size());
+                std::iota(order.begin(), order.end(), 0);
+
+                mu = std::make_shared<op::Reshape>(mu, order, reshape);
+
+                std::shared_ptr<Node> diff = make_with_numpy_broadcast<op::Subtract>(node, mu);
+
+                diff = std::make_shared<op::Sum>(diff * diff, reduction_axes);
 
                 const auto& et = node->get_element_type();
                 auto N = get_num_elements(node->get_shape(), reduction_axes);
 
-                auto Nconst = op::Constant::create(et, xsum->get_shape(), {N});
-                auto xbar2 = (xsum * xsum) / Nconst;
-                auto diff = x2sum - xbar2;
-
                 if (bessel_correction)
                 {
-                    auto N1const = op::Constant::create(et, xsum->get_shape(), {N - 1});
+                    auto N1const = op::Constant::create(et, diff->get_shape(), {N - 1});
                     return diff / N1const;
                 }
                 else
                 {
+                    auto Nconst = op::Constant::create(et, diff->get_shape(), {N});
                     return diff / Nconst;
                 }
             }
...
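Note on the hunk above: the old code computed variance in one pass as E[x²] − E[x]² (sum of squares minus squared mean), which cancels catastrophically in float32 whenever the mean is large relative to the spread. The new code takes the mean first, reshapes it so each reduced axis survives as a size-1 dimension (the identity AxisVector built with std::iota reorders nothing; the Reshape just reinstates the dropped axes so make_with_numpy_broadcast can line the mean up against the input), then sums the squared differences. A minimal standalone sketch of the numerical motivation, in plain C++ rather than the nGraph graph API (variance_sum_of_squares and variance_two_pass are illustrative names, not library functions):

#include <cstdio>
#include <vector>

// Old formulation: (sum(x^2) - sum(x)^2 / N) / N, one pass over the data.
float variance_sum_of_squares(const std::vector<float>& x)
{
    float sum = 0.f, sum_sq = 0.f;
    for (float v : x)
    {
        sum += v;
        sum_sq += v * v;
    }
    float n = static_cast<float>(x.size());
    return (sum_sq - (sum * sum) / n) / n;
}

// New formulation: sum((x - mean)^2) / N, two passes but mean-centered.
float variance_two_pass(const std::vector<float>& x)
{
    float sum = 0.f;
    for (float v : x)
        sum += v;
    float mean = sum / static_cast<float>(x.size());
    float acc = 0.f;
    for (float v : x)
        acc += (v - mean) * (v - mean);
    return acc / static_cast<float>(x.size());
}

int main()
{
    // True variance of {9999, 10000, 10001} is 2/3.
    std::vector<float> x = {9999.f, 10000.f, 10001.f};
    std::printf("sum-of-squares: %f\n", variance_sum_of_squares(x)); // typically 0.000000 in f32
    std::printf("two-pass:       %f\n", variance_two_pass(x));       // close to 0.666667
    return 0;
}

With IEEE float32 arithmetic the one-pass form typically collapses to 0 here, because sum_sq and sum²/N both round to the same value near 3e8; the two-pass form keeps the answer. The graph-level rewrite in this commit makes the same trade, at the cost of a second traversal of the data.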
@@ -2682,14 +2682,13 @@ namespace ngraph
             void CPU_Emitter::EMITTER_DECL(ngraph::op::MaxPoolBackprop)
             {
                 auto mpb = static_cast<const ngraph::op::MaxPoolBackprop*>(node);
-                auto max_pool_fprop_op = mpb->get_forward_op();
 
                 auto delta_shape = args[1].get_shape();
                 auto delta_rank = delta_shape.size();
                 auto out_shape = out[0].get_shape();
 
                 if (delta_rank == 4 && mpb->get_window_shape().size() == 2 &&
-                    args[0].get_element_type() == element::f32 && max_pool_fprop_op != nullptr)
+                    args[0].get_element_type() == element::f32)
                 {
                     const string& et =
                         get_mkldnn_data_type(args[1].get_element_type().c_type_string());
@@ -2725,10 +2724,10 @@ namespace ngraph
                            "pooling_forward::primitive_desc("
                         << "{prop_kind::forward, algorithm::pooling_max, "
                         << "max_pool_input_desc, max_pool_result_desc, {"
-                        << join(max_pool_fprop_op->get_window_movement_strides()) << "}, {"
-                        << join(max_pool_fprop_op->get_window_shape()) << "}, "
-                        << "{" << join(max_pool_fprop_op->get_padding_below()) << "}, "
-                        << "{" << join(max_pool_fprop_op->get_padding_above()) << "}, "
+                        << join(mpb->get_window_movement_strides()) << "}, {"
+                        << join(mpb->get_window_shape()) << "}, "
+                        << "{" << join(mpb->get_padding_below()) << "}, "
+                        << "{" << join(mpb->get_padding_above()) << "}, "
                         << "padding_kind::zero}, cpu_engine);\n";
                     // query the workspace from the forward primitive desc and allocates memory
...
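Note on these two hunks: MaxPoolBackprop carries its own window shape, movement strides, and padding (the mpb->get_* accessors used above), so the emitter can rebuild the MKL-DNN forward primitive descriptor from the backprop node directly instead of chasing get_forward_op(), which could be null and forced the extra guard in the if condition. A hedged sketch of that ownership pattern, with invented names (MaxPoolBackpropSketch is not the real nGraph class, which carries much more machinery):

#include <cstddef>
#include <utility>
#include <vector>

using Shape = std::vector<std::size_t>;
using Strides = std::vector<std::size_t>;

// The backprop op owns copies of the pooling attributes, so consumers never
// depend on the forward op still being alive (or ever having existed).
class MaxPoolBackpropSketch
{
public:
    MaxPoolBackpropSketch(Shape window_shape,
                          Strides window_movement_strides,
                          Shape padding_below,
                          Shape padding_above)
        : m_window_shape(std::move(window_shape))
        , m_window_movement_strides(std::move(window_movement_strides))
        , m_padding_below(std::move(padding_below))
        , m_padding_above(std::move(padding_above))
    {
    }

    // An emitter reads everything it needs from these accessors; no
    // get_forward_op() pointer and no null check required.
    const Shape& get_window_shape() const { return m_window_shape; }
    const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
    const Shape& get_padding_below() const { return m_padding_below; }
    const Shape& get_padding_above() const { return m_padding_above; }

private:
    Shape m_window_shape;
    Strides m_window_movement_strides;
    Shape m_padding_below;
    Shape m_padding_above;
};

int main()
{
    // The forward pooling descriptor is reconstructible from the backprop
    // op alone: 2x2 window, stride 2, no padding.
    MaxPoolBackpropSketch mpb({2, 2}, {2, 2}, {0, 0}, {0, 0});
    return mpb.get_window_shape().size() == 2 ? 0 : 1;
}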