Commit 84d236ad authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

rewrite variance to use 2 pass (#520)

parent 59bdd6ee
......@@ -15,10 +15,12 @@
*******************************************************************************/
#include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp"
......@@ -80,27 +82,34 @@ namespace ngraph
const AxisSet& reduction_axes,
const bool bessel_correction)
{
auto xsum = std::make_shared<op::Sum>(node, reduction_axes);
std::shared_ptr<Node> mu = mean(node, reduction_axes);
auto x2 = node * node;
auto reshape = node->get_shape();
for (auto i : reduction_axes)
{
reshape[i] = 1;
}
auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes);
ngraph::AxisVector order(mu->get_shape().size());
std::iota(order.begin(), order.end(), 0);
const auto& et = node->get_element_type();
auto N = get_num_elements(node->get_shape(), reduction_axes);
mu = std::make_shared<op::Reshape>(mu, order, reshape);
std::shared_ptr<Node> diff = make_with_numpy_broadcast<op::Subtract>(node, mu);
auto Nconst = op::Constant::create(et, xsum->get_shape(), {N});
auto xbar2 = (xsum * xsum) / Nconst;
diff = std::make_shared<op::Sum>(diff * diff, reduction_axes);
auto diff = x2sum - xbar2;
const auto& et = node->get_element_type();
auto N = get_num_elements(node->get_shape(), reduction_axes);
if (bessel_correction)
{
auto N1const = op::Constant::create(et, xsum->get_shape(), {N - 1});
auto N1const = op::Constant::create(et, diff->get_shape(), {N - 1});
return diff / N1const;
}
else
{
auto Nconst = op::Constant::create(et, diff->get_shape(), {N});
return diff / Nconst;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment