Commit 15f50ce1 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

sum fix (#930)

parent 42676ed8
...@@ -141,6 +141,36 @@ static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape) ...@@ -141,6 +141,36 @@ static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape)
return prod; return prod;
} }
template <typename T>
static std::shared_ptr<Node>
multiply_by(element::Type type, size_t multiplier, std::shared_ptr<op::Constant> cnst)
{
T sum_cnst = static_cast<T>(cnst->get_vector<T>().at(0) * multiplier);
return op::Constant::create<T>(type, Shape{}, {sum_cnst});
}
static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst, size_t multiplier)
{
if (cnst->get_element_type() == element::i32)
{
return multiply_by<int>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::i8)
{
return multiply_by<char>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f32)
{
return multiply_by<float>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f64)
{
return multiply_by<double>(cnst->get_element_type(), multiplier, cnst);
}
return nullptr;
}
//`simplify_sum` optimizes the following case: //`simplify_sum` optimizes the following case:
//sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant) //sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
//where constant2's values are equal to scalar_constant * shape_size(reduction_axes) //where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
...@@ -164,10 +194,15 @@ static bool simplify_sum(std::shared_ptr<Node> n) ...@@ -164,10 +194,15 @@ static bool simplify_sum(std::shared_ptr<Node> n)
} }
auto multiplier = reduction_shape_size(sum->get_reduction_axes(), broadcast->get_shape()); auto multiplier = reduction_shape_size(sum->get_reduction_axes(), broadcast->get_shape());
double sum_const_value = cnst->get_vector<double>().at(0) * multiplier; auto sum_cnst = get_sum_constant(cnst, multiplier);
std::shared_ptr<Node> sum_cnst =
op::Constant::create(cnst->get_element_type(), Shape{}, {sum_const_value}); //Unsupported type
auto new_node = sum_cnst; if (!sum_cnst)
{
NGRAPH_DEBUG << "unsupported type";
return false;
}
if (sum->get_shape().size() > 0) if (sum->get_shape().size() > 0)
{ {
ngraph::AxisSet axes{}; ngraph::AxisSet axes{};
...@@ -175,10 +210,10 @@ static bool simplify_sum(std::shared_ptr<Node> n) ...@@ -175,10 +210,10 @@ static bool simplify_sum(std::shared_ptr<Node> n)
{ {
axes.insert(i); axes.insert(i);
} }
new_node = std::make_shared<op::Broadcast>(sum_cnst, sum->get_shape(), axes); sum_cnst = std::make_shared<op::Broadcast>(sum_cnst, sum->get_shape(), axes);
} }
ngraph::replace_node(n, new_node); ngraph::replace_node(n, sum_cnst);
return true; return true;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment