Commit 388f449b authored by Amy Zhuang, committed by Scott Cyphers

Add constant folding for v1 reduce ops. (#3791)

* Add constant folding for v1 reduce ops.

* Add reference/mean.hpp.

* Remove extra semicolon.

* Address PR feedback.
parent 1ad0d723
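
For context, a minimal sketch of what the pass now handles, not part of the commit itself; the exact construction calls are assumptions based on the nGraph v1 API of this period:

    #include "ngraph/ngraph.hpp"
    #include "ngraph/pass/constant_folding.hpp"
    #include "ngraph/pass/manager.hpp"

    using namespace ngraph;

    int main()
    {
        // A v1::ReduceSum whose data and axes inputs are both Constants...
        auto data = op::Constant::create(element::f32, Shape{2, 3}, {1, 2, 3, 4, 5, 6});
        auto axes = op::Constant::create(element::i64, Shape{1}, {1});
        auto sum = std::make_shared<op::v1::ReduceSum>(data, axes, /*keep_dims=*/false);
        auto f = std::make_shared<Function>(NodeVector{sum}, ParameterVector{});

        // ...is replaced by ConstantFolding with a precomputed Constant
        // holding {6, 15} (the row sums over axis 1).
        pass::Manager manager;
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
    }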
@@ -181,6 +181,7 @@ namespace ngraph
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/quantized_dot.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/reduce_mean.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/relu.hpp"
@@ -19,8 +19,12 @@
#include "ngraph/op/max.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce_mean.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/mean.hpp"
#include "ngraph/runtime/reference/min.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/sum.hpp"
@@ -43,6 +47,24 @@ static shared_ptr<op::Constant>
reduction_node->get_shape(),
max->get_reduction_axes());
}
else if (auto reduce_max = as_type_ptr<op::v1::ReduceMax>(reduction_node))
{
auto reduction_axes = reduce_max->get_reduction_axes();
auto input_shape = reduce_max->get_input_shape(0);
Shape shape_no_keep_dims;
for (size_t i = 0; i < input_shape.size(); i++)
{
if (reduction_axes.count(i) == 0)
{
shape_no_keep_dims.push_back(input_shape[i]);
}
}
runtime::reference::max<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
shape_no_keep_dims,
reduce_max->get_reduction_axes());
}
else if (auto min = as_type_ptr<op::Min>(reduction_node))
{
runtime::reference::min<T>(constant->get_vector<T>().data(),
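
A worked example of the shape_no_keep_dims loop above, which each of the new v1 branches repeats: the v1 reduce ops can keep the reduced dimensions (as size-1 axes) in their own output shape, so the fold recomputes the plain reduced shape that the reference kernels expect. The values here are illustrative:

    // Reducing a {2, 3, 4} constant over axis 1.
    Shape input_shape{2, 3, 4};
    AxisSet reduction_axes{1};
    Shape shape_no_keep_dims;
    for (size_t i = 0; i < input_shape.size(); i++)
    {
        if (reduction_axes.count(i) == 0)
        {
            shape_no_keep_dims.push_back(input_shape[i]); // keeps axes 0 and 2
        }
    }
    // shape_no_keep_dims is {2, 4}; with keep_dims=true the node's own output
    // shape would be {2, 1, 4}, which the reference kernel does not accept.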
@@ -51,6 +73,24 @@ static shared_ptr<op::Constant>
reduction_node->get_shape(),
min->get_reduction_axes());
}
else if (auto reduce_min = as_type_ptr<op::v1::ReduceMin>(reduction_node))
{
auto reduction_axes = reduce_min->get_reduction_axes();
auto input_shape = reduce_min->get_input_shape(0);
Shape shape_no_keep_dims;
for (size_t i = 0; i < input_shape.size(); i++)
{
if (reduction_axes.count(i) == 0)
{
shape_no_keep_dims.push_back(input_shape[i]);
}
}
runtime::reference::min<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
shape_no_keep_dims,
reduce_min->get_reduction_axes());
}
else if (auto prod = as_type_ptr<op::Product>(reduction_node))
{
runtime::reference::product<T>(constant->get_vector<T>().data(),
@@ -59,6 +99,24 @@ static shared_ptr<op::Constant>
reduction_node->get_shape(),
prod->get_reduction_axes());
}
else if (auto reduce_prod = as_type_ptr<op::v1::ReduceProd>(reduction_node))
{
auto reduction_axes = reduce_prod->get_reduction_axes();
auto input_shape = reduce_prod->get_input_shape(0);
Shape shape_no_keep_dims;
for (size_t i = 0; i < input_shape.size(); i++)
{
if (reduction_axes.count(i) == 0)
{
shape_no_keep_dims.push_back(input_shape[i]);
}
}
runtime::reference::product<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
shape_no_keep_dims,
reduce_prod->get_reduction_axes());
}
else if (auto sum = as_type_ptr<op::Sum>(reduction_node))
{
runtime::reference::sum<T>(constant->get_vector<T>().data(),
@@ -67,6 +125,42 @@ static shared_ptr<op::Constant>
reduction_node->get_shape(),
sum->get_reduction_axes());
}
else if (auto reduce_sum = as_type_ptr<op::v1::ReduceSum>(reduction_node))
{
auto reduction_axes = reduce_sum->get_reduction_axes();
auto input_shape = reduce_sum->get_input_shape(0);
Shape shape_no_keep_dims;
for (size_t i = 0; i < input_shape.size(); i++)
{
if (reduction_axes.count(i) == 0)
{
shape_no_keep_dims.push_back(input_shape[i]);
}
}
runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
shape_no_keep_dims,
reduce_sum->get_reduction_axes());
}
else if (auto reduce_mean = as_type_ptr<op::v1::ReduceMean>(reduction_node))
{
auto reduction_axes = reduce_mean->get_reduction_axes();
auto input_shape = reduce_mean->get_input_shape(0);
Shape shape_no_keep_dims;
for (size_t i = 0; i < input_shape.size(); i++)
{
if (reduction_axes.count(i) == 0)
{
shape_no_keep_dims.push_back(input_shape[i]);
}
}
runtime::reference::mean<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_output_shape(0),
shape_no_keep_dims,
reduce_mean->get_reduction_axes());
}
else
{
NGRAPH_CHECK(false,
@@ -134,7 +228,12 @@ void pass::ConstantFolding::construct_constant_arithmetic_reduction()
make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
auto is_supported_reduction = [](std::shared_ptr<Node> n) {
return (pattern::has_class<op::Max>()(n) || pattern::has_class<op::Min>()(n) ||
pattern::has_class<op::Product>()(n) || pattern::has_class<op::Sum>()(n));
pattern::has_class<op::Product>()(n) || pattern::has_class<op::Sum>()(n) ||
pattern::has_class<op::v1::ReduceMax>()(n) ||
pattern::has_class<op::v1::ReduceMin>()(n) ||
pattern::has_class<op::v1::ReduceProd>()(n) ||
pattern::has_class<op::v1::ReduceSum>()(n) ||
pattern::has_class<op::v1::ReduceMean>()(n));
};
auto reduction =
std::make_shared<pattern::op::Any>(element::i32,
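
For reference, pattern::has_class<T>() in the predicate above returns a callable that tests a node's concrete type, so the new v1 ops are matched the same way as the existing v0 ops. A hypothetical standalone use:

    auto data = std::make_shared<op::Parameter>(element::f32, Shape{4});
    auto axes = op::Constant::create(element::i64, Shape{1}, {0});
    auto n = std::make_shared<op::v1::ReduceMean>(data, axes, /*keep_dims=*/false);
    bool supported = pattern::has_class<op::v1::ReduceMean>()(n); // true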
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include <cmath>
#include <map>
#include <vector>

#include "ngraph/coordinate_transform.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/shape_util.hpp"
#include "ngraph/type/bfloat16.hpp"
#include "ngraph/type/float16.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            template <typename T>
            void mean(const T* arg,
                      T* out,
                      const Shape& in_shape,
                      const Shape& out_shape,
                      const AxisSet& reduction_axes)
            {
                CoordinateTransform output_transform(out_shape);
                std::vector<T> cs(shape_size(out_shape));

                // Zero the output buffer and the Kahan compensation terms.
                for (const Coordinate& output_coord : output_transform)
                {
                    out[output_transform.index(output_coord)] = 0;
                    cs[output_transform.index(output_coord)] = 0;
                }

                CoordinateTransform input_transform(in_shape);
                std::map<size_t, int> index_to_count_map;

                for (const Coordinate& input_coord : input_transform)
                {
                    Coordinate output_coord = reduce(input_coord, reduction_axes);

                    T x = arg[input_transform.index(input_coord)];
                    T& z = out[output_transform.index(output_coord)];

                    // Count how many input elements fold into each output
                    // element; the counts become the divisors at the end.
                    auto index = output_transform.index(output_coord);
                    if (index_to_count_map.find(index) == index_to_count_map.end())
                    {
                        index_to_count_map[index] = 1;
                    }
                    else
                    {
                        index_to_count_map[index]++;
                    }

                    if (is_finite(x) && is_finite(z))
                    {
                        // Kahan (compensated) summation: c carries the
                        // low-order bits lost by earlier additions
                        // (is_finite comes from reference/sum.hpp).
                        T& c = cs[output_transform.index(output_coord)];
                        T t = z + (x - c);
                        c = (t - z) - (x - c);
                        z = t;
                    }
                    else
                    {
                        // Non-finite values bypass the compensation so that
                        // infinities and NaNs propagate unchanged.
                        z = z + x;
                    }
                }

                // Divide each accumulated sum by its element count.
                for (const Coordinate& output_coord : output_transform)
                {
                    auto count = index_to_count_map[output_transform.index(output_coord)];
                    out[output_transform.index(output_coord)] =
                        out[output_transform.index(output_coord)] / count;
                }
            }
        }
    }
}
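
The cs buffer and the t/c updates above implement Kahan (compensated) summation, which keeps each output's rounding error bounded no matter how many elements reduce into it. A minimal standalone demonstration of the same update rule; the inputs and iteration count are illustrative:

    #include <cstdio>

    int main()
    {
        // Naive float accumulation: each 1e-8f is below half an ulp of 1.0f,
        // so the running sum never moves off 1.0.
        float naive = 1.0f;
        for (int i = 0; i < 100000000; i++)
        {
            naive += 1e-8f;
        }

        // Compensated accumulation, as in mean() above: c carries the lost
        // low-order bits until they are large enough to register in z.
        float z = 1.0f;
        float c = 0.0f;
        for (int i = 0; i < 100000000; i++)
        {
            float x = 1e-8f;
            float t = z + (x - c);
            c = (t - z) - (x - c);
            z = t;
        }

        std::printf("naive = %f, kahan = %f\n", naive, z); // ~1.0 vs ~2.0
    }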