Commit 15f50ce1 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

sum fix (#930)

parent 42676ed8
......@@ -141,6 +141,36 @@ static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape)
return prod;
}
template <typename T>
static std::shared_ptr<Node>
multiply_by(element::Type type, size_t multiplier, std::shared_ptr<op::Constant> cnst)
{
T sum_cnst = static_cast<T>(cnst->get_vector<T>().at(0) * multiplier);
return op::Constant::create<T>(type, Shape{}, {sum_cnst});
}
static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst, size_t multiplier)
{
if (cnst->get_element_type() == element::i32)
{
return multiply_by<int>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::i8)
{
return multiply_by<char>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f32)
{
return multiply_by<float>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f64)
{
return multiply_by<double>(cnst->get_element_type(), multiplier, cnst);
}
return nullptr;
}
//`simplify_sum` optimizes the following case:
//sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
//where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
......@@ -164,10 +194,15 @@ static bool simplify_sum(std::shared_ptr<Node> n)
}
auto multiplier = reduction_shape_size(sum->get_reduction_axes(), broadcast->get_shape());
double sum_const_value = cnst->get_vector<double>().at(0) * multiplier;
std::shared_ptr<Node> sum_cnst =
op::Constant::create(cnst->get_element_type(), Shape{}, {sum_const_value});
auto new_node = sum_cnst;
auto sum_cnst = get_sum_constant(cnst, multiplier);
//Unsupported type
if (!sum_cnst)
{
NGRAPH_DEBUG << "unsupported type";
return false;
}
if (sum->get_shape().size() > 0)
{
ngraph::AxisSet axes{};
......@@ -175,10 +210,10 @@ static bool simplify_sum(std::shared_ptr<Node> n)
{
axes.insert(i);
}
new_node = std::make_shared<op::Broadcast>(sum_cnst, sum->get_shape(), axes);
sum_cnst = std::make_shared<op::Broadcast>(sum_cnst, sum->get_shape(), axes);
}
ngraph::replace_node(n, new_node);
ngraph::replace_node(n, sum_cnst);
return true;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment