Commit 30649733 authored by Nick Korovaiko, committed by Scott Cyphers

get_default_order (#1452)

* get_default_order

* add a newline

* address Scott's feedback

* add <numeric> header
parent 176e105b
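
What changed, at a glance: the hand-rolled "identity axis order" idiom (construct an AxisVector, fill it with std::iota) is replaced across the passes by a shared ngraph::get_default_order helper declared in ngraph/util.hpp. The sketch below contrasts the removed idiom with its replacement; the wrapper names old_default_order/new_default_order are illustrative only, not part of the commit.

// Illustration only: old_default_order/new_default_order are not from the commit;
// they just contrast the removed idiom with its replacement.
#include <numeric>                 // std::iota
#include "ngraph/axis_vector.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"         // declares ngraph::get_default_order after this commit

// The pattern repeated across the passes before this change:
ngraph::AxisVector old_default_order(const ngraph::Shape& shape)
{
    ngraph::AxisVector order(shape.size());
    std::iota(order.begin(), order.end(), 0); // {0, 1, ..., rank - 1}
    return order;
}

// The single call the sites use now:
ngraph::AxisVector new_default_order(const ngraph::Shape& shape)
{
    return ngraph::get_default_order(shape);
}
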
@@ -162,8 +162,7 @@ namespace ngraph
 if (node->get_shape() != node_shape_after_possible_reshaping)
 {
     // tell reshape to examine input dimensions in order
-    ngraph::AxisVector order(node->get_shape().size());
-    std::iota(order.begin(), order.end(), 0);
+    ngraph::AxisVector order = ngraph::get_default_order(node->get_shape());
     return_node = std::make_shared<ngraph::op::Reshape>(
         return_node, order, node_shape_after_possible_reshaping);
 }
...
@@ -26,6 +26,7 @@
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/subtract.hpp"
 #include "ngraph/op/sum.hpp"
+#include "ngraph/util.hpp"

 namespace ngraph
 {
@@ -93,8 +94,7 @@ namespace ngraph
     reshape[i] = 1;
 }
-ngraph::AxisVector order(mu->get_shape().size());
-std::iota(order.begin(), order.end(), 0);
+ngraph::AxisVector order = ngraph::get_default_order(mu->get_shape());
 mu = std::make_shared<op::Reshape>(mu, order, reshape);
...
@@ -26,6 +26,7 @@
 #include "ngraph/op/less.hpp"
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/shape.hpp"
+#include "ngraph/util.hpp"

 namespace ngraph
 {
...
@@ -525,8 +525,7 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
 }
 auto swap_NC = [](const shared_ptr<Node> n) {
-    AxisVector ax_order(n->get_shape().size());
-    iota(ax_order.begin(), ax_order.end(), 0);
+    AxisVector ax_order = ngraph::get_default_order(n->get_shape());
     ax_order[0] = 1;
     ax_order[1] = 0;
...
@@ -24,6 +24,7 @@
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/subtract.hpp"
 #include "ngraph/op/sum.hpp"
+#include "ngraph/util.hpp"

 using namespace std;
 using namespace ngraph;
@@ -78,8 +79,7 @@ void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
         shape.push_back(1);
     }
 }
-AxisVector order(zsum->get_shape().size());
-iota(order.begin(), order.end(), 0);
+auto order = ngraph::get_default_order(zsum->get_shape());
 auto zreshape = make_shared<op::Reshape>(zsum, order, shape);
 auto adjoint =
...
@@ -37,6 +37,7 @@
 #include "ngraph/op/subtract.hpp"
 #include "ngraph/op/sum.hpp"
 #include "ngraph/pattern/matcher.hpp"
+#include "ngraph/util.hpp"

 using namespace ngraph;
@@ -116,8 +117,7 @@ static bool simplify_concat(std::shared_ptr<Node> n)
 {
     if (auto rcarg = std::dynamic_pointer_cast<op::Reshape>(carg))
     {
-        Shape default_shape(rcarg->get_argument(0)->get_shape().size());
-        std::iota(begin(default_shape), end(default_shape), 0);
+        auto default_shape = ngraph::get_default_order(rcarg->get_argument(0)->get_shape());
         if (default_shape != rcarg->get_input_order())
         {
             NGRAPH_DEBUG << carg->get_name() << " reshape also does transposes";
...
@@ -73,8 +73,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
     return false;
 }
-Shape do_r1(r1->get_shape().size());
-std::iota(begin(do_r1), end(do_r1), 0);
+auto do_r1 = ngraph::get_default_order(r1->get_shape());
 if (do_r1 != r1->get_input_order())
 {
@@ -119,10 +118,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
 auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root());
 auto r1 = std::dynamic_pointer_cast<op::Reshape>(r2->get_argument(0));
-Shape do_r2(r1->get_shape().size());
-std::iota(begin(do_r2), end(do_r2), 0);
-Shape do_r1(gop->get_shape().size());
-std::iota(begin(do_r1), end(do_r1), 0);
+auto do_r2 = ngraph::get_default_order(r1->get_shape());
+auto do_r1 = ngraph::get_default_order(gop->get_shape());
 NGRAPH_DEBUG << "r1's i/o = " << vector_to_string(r1->get_input_order())
              << "do_r1 = " << vector_to_string(do_r1);
...
@@ -23,17 +23,10 @@
 #include "ngraph/log.hpp"
 #include "ngraph/op/broadcast.hpp"
 #include "ngraph/op/reshape.hpp"
+#include "ngraph/util.hpp"

 using namespace ngraph;

-static void get_default_order(std::vector<size_t>& order, size_t rank)
-{
-    for (size_t i = 0; i < rank; i++)
-    {
-        order.push_back(i);
-    }
-}
-
 struct CollapsedDims
 {
     std::vector<size_t> output_shape;
@@ -114,8 +107,7 @@ bool runtime::cpu::pass::CPUCollapseDims::run_on_function(std::shared_ptr<ngraph
 if (cdims.axis_set.size() == 0)
 {
     // Null broadcast operation, replace with reshape
-    AxisVector axis_order;
-    get_default_order(axis_order, input_shape.size());
+    AxisVector axis_order = ngraph::get_default_order(input_shape);
     auto reshape = std::make_shared<op::Reshape>(
         node->get_argument(0), axis_order, Shape(cdims.output_shape));
     ngraph::replace_node(n, reshape);
@@ -124,8 +116,7 @@ bool runtime::cpu::pass::CPUCollapseDims::run_on_function(std::shared_ptr<ngraph
 else if (output_shape.size() != cdims.output_shape.size())
 {
     // Reshape arg to collapsed input_shape
-    AxisVector input_axis_order;
-    get_default_order(input_axis_order, input_shape.size());
+    AxisVector input_axis_order = ngraph::get_default_order(input_shape);
     auto reshape_input = std::make_shared<op::Reshape>(
         node->get_argument(0), input_axis_order, Shape(cdims.input_shape));
@@ -133,8 +124,7 @@ bool runtime::cpu::pass::CPUCollapseDims::run_on_function(std::shared_ptr<ngraph
     reshape_input, Shape(cdims.output_shape), AxisSet(cdims.axis_set));

 // Reshape collapsed output to original output_shape
-AxisVector output_axis_order;
-get_default_order(output_axis_order, cdims.output_shape.size());
+AxisVector output_axis_order = ngraph::get_default_order(cdims.output_shape);
 auto reshape_output =
     std::make_shared<op::Reshape>(broadcast, output_axis_order, output_shape);
 ngraph::replace_node(n, reshape_output);
...
@@ -14,11 +14,12 @@
 * limitations under the License.
 *******************************************************************************/

+#include "cpu_fusion.hpp"
+
 #include <algorithm>
 #include <iostream>
 #include <numeric>
 #include <unordered_set>
-#include "cpu_fusion.hpp"
 #include "ngraph/graph_util.hpp"
 #include "ngraph/log.hpp"
 #include "ngraph/op/add.hpp"
@@ -53,6 +54,7 @@
 #include "ngraph/runtime/cpu/op/conv_relu.hpp"
 #include "ngraph/runtime/cpu/op/matmul_bias.hpp"
 #include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
+#include "ngraph/util.hpp"

 static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
                            std::shared_ptr<ngraph::Node> arg,
@@ -82,9 +84,7 @@ static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
 auto io = r_w->get_input_order();
 if (r_w->get_shape().size() != arg->get_shape().size()) //reshape
 {
-    ngraph::AxisVector dio(io.size());
-    std::iota(begin(dio), end(dio), 0);
+    auto dio = ngraph::get_default_order(io);
     if (io != dio) //we can't reshape and transpose at the same time
     {
         NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " is not in default order "
@@ -636,8 +636,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
 NGRAPH_DEBUG
     << "mpattern = " << m.get_match_root()->get_name()
     << "conv_bias bias shape != 1, requires reshape to match filter count.";
-ngraph::AxisVector order(bias_shape.size());
-std::iota(begin(order), end(order), 0);
+auto order = ngraph::get_default_order(bias_shape);
 auto bias_reshape =
     std::make_shared<op::Reshape>(bias, order, Shape{conv->get_input_shape(1)[0]});
 auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias_reshape));
@@ -698,8 +697,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_bprop()
 NGRAPH_DEBUG
     << "mpattern = " << m.get_match_root()->get_name()
     << "conv_bias bias shape != 1, requires reshape to match filter count.";
-ngraph::AxisVector order(bias_shape.size());
-std::iota(begin(order), end(order), 0);
+auto order = ngraph::get_default_order(bias_shape);
 auto bias_reshape = std::make_shared<op::Reshape>(
     bias, order, Shape{conv_bprop->get_filters_shape()[0]});
 bias_shape = bias_reshape->get_shape();
...
@@ -33,6 +33,7 @@
 #include "ngraph/pattern/op/label.hpp"
 #include "ngraph/runtime/cpu/op/batch_dot.hpp"
 #include "ngraph/runtime/cpu/op/group_conv.hpp"
+#include "ngraph/util.hpp"

 using namespace ngraph;
@@ -166,8 +167,7 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
 const auto& data_shape = data_node->get_shape();

 // construct new op nodes
-AxisVector data_order(data_node->get_shape().size());
-std::iota(begin(data_order), end(data_order), 0);
+auto data_order = ngraph::get_default_order(data_node->get_shape());
 auto data_reshape_node = std::make_shared<op::Reshape>(
     data_node, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]});
...
@@ -14,11 +14,13 @@
 * limitations under the License.
 *******************************************************************************/

+#include <algorithm>
 #include <cassert>
 #include <deque>
 #include <forward_list>
 #include <iomanip>
 #include <map>
+#include <numeric>
 #include <unordered_set>

 #include "ngraph/function.hpp"
@@ -449,3 +451,14 @@ void ngraph::check_fp_values_isnan(const char* name, const double* array, size_t
         }
     }
 }
+AxisVector ngraph::get_default_order(const Shape& shape)
+{
+    return get_default_order(shape.size());
+}
+
+AxisVector ngraph::get_default_order(size_t rank)
+{
+    AxisVector default_order(rank);
+    std::iota(begin(default_order), end(default_order), 0);
+    return default_order;
+}
@@ -26,7 +26,9 @@
 #include <string>
 #include <vector>

+#include "ngraph/axis_vector.hpp"
 #include "ngraph/node_vector.hpp"
+#include "ngraph/shape.hpp"

 namespace ngraph
 {
@@ -221,6 +223,9 @@ namespace ngraph
     void aligned_free(void*);
     size_t round_up(size_t size, size_t alignment);

+    AxisVector get_default_order(size_t rank);
+    AxisVector get_default_order(const Shape& shape);
+
     /*
     * Return type struct for cache_fprop, with the modified fprop and bprop
     * functions
...
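
For reference, here is one way the new helper is used at the call sites above: build the default (identity) input order for a Reshape that only changes the shape and never transposes. The reshape_to wrapper below is a hypothetical sketch, not code from this commit.

#include <memory>
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"

// Hypothetical helper (illustration only): reshape n to new_shape while reading
// its dimensions in their existing order, which is what get_default_order expresses.
std::shared_ptr<ngraph::Node> reshape_to(const std::shared_ptr<ngraph::Node>& n,
                                         const ngraph::Shape& new_shape)
{
    // {0, 1, ..., rank - 1}: keep the input dimensions in order.
    ngraph::AxisVector order = ngraph::get_default_order(n->get_shape());
    return std::make_shared<ngraph::op::Reshape>(n, order, new_shape);
}

Providing both a rank-based and a Shape-based overload keeps call sites terse whether they already hold a Shape or only its rank.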