Commit 84d236ad authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

rewrite variance to use 2 pass (#520)

parent 59bdd6ee
...@@ -15,10 +15,12 @@ ...@@ -15,10 +15,12 @@
*******************************************************************************/ *******************************************************************************/
#include "ngraph/builder/reduce_ops.hpp" #include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/ops/add.hpp" #include "ngraph/ops/add.hpp"
#include "ngraph/ops/divide.hpp" #include "ngraph/ops/divide.hpp"
#include "ngraph/ops/multiply.hpp" #include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/power.hpp" #include "ngraph/ops/power.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp" #include "ngraph/ops/sum.hpp"
...@@ -80,27 +82,34 @@ namespace ngraph ...@@ -80,27 +82,34 @@ namespace ngraph
const AxisSet& reduction_axes, const AxisSet& reduction_axes,
const bool bessel_correction) const bool bessel_correction)
{ {
auto xsum = std::make_shared<op::Sum>(node, reduction_axes); std::shared_ptr<Node> mu = mean(node, reduction_axes);
auto x2 = node * node; auto reshape = node->get_shape();
for (auto i : reduction_axes)
{
reshape[i] = 1;
}
auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes); ngraph::AxisVector order(mu->get_shape().size());
std::iota(order.begin(), order.end(), 0);
const auto& et = node->get_element_type(); mu = std::make_shared<op::Reshape>(mu, order, reshape);
auto N = get_num_elements(node->get_shape(), reduction_axes);
std::shared_ptr<Node> diff = make_with_numpy_broadcast<op::Subtract>(node, mu);
auto Nconst = op::Constant::create(et, xsum->get_shape(), {N}); diff = std::make_shared<op::Sum>(diff * diff, reduction_axes);
auto xbar2 = (xsum * xsum) / Nconst;
auto diff = x2sum - xbar2; const auto& et = node->get_element_type();
auto N = get_num_elements(node->get_shape(), reduction_axes);
if (bessel_correction) if (bessel_correction)
{ {
auto N1const = op::Constant::create(et, xsum->get_shape(), {N - 1}); auto N1const = op::Constant::create(et, diff->get_shape(), {N - 1});
return diff / N1const; return diff / N1const;
} }
else else
{ {
auto Nconst = op::Constant::create(et, diff->get_shape(), {N});
return diff / Nconst; return diff / Nconst;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment