Commit 04230095 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

apply perm (#1460)

parent abff494d
...@@ -33,23 +33,9 @@ ...@@ -33,23 +33,9 @@
#include "ngraph/pattern/op/skip.hpp" #include "ngraph/pattern/op/skip.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
template <typename T> extern template ngraph::AxisVector
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order) ngraph::apply_permutation<ngraph::AxisVector>(ngraph::AxisVector input,
{ ngraph::AxisVector order);
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
std::vector<T> output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern() void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
{ {
...@@ -133,8 +119,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -133,8 +119,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
return true; return true;
} }
auto perm1 = apply_permutation(do_r1, r1->get_input_order()); auto perm1 = ngraph::apply_permutation(do_r1, r1->get_input_order());
auto perm2 = apply_permutation(perm1, r2->get_input_order()); auto perm2 = ngraph::apply_permutation(perm1, r2->get_input_order());
if (perm2 == do_r1) if (perm2 == do_r1)
{ {
NGRAPH_DEBUG << "Two transposes were removed!"; NGRAPH_DEBUG << "Two transposes were removed!";
......
...@@ -56,6 +56,9 @@ ...@@ -56,6 +56,9 @@
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp" #include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input,
ngraph::AxisVector order);
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape, static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
std::shared_ptr<ngraph::Node> arg, std::shared_ptr<ngraph::Node> arg,
bool& transpose_w, bool& transpose_w,
...@@ -108,24 +111,6 @@ static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape, ...@@ -108,24 +111,6 @@ static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
return true; return true;
} }
template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
std::vector<T> output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias() void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
{ {
Shape shape_w{2, 4}; Shape shape_w{2, 4};
...@@ -427,8 +412,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv( ...@@ -427,8 +412,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
std::dynamic_pointer_cast<op::Reshape>(pattern_map[reshape_label]); std::dynamic_pointer_cast<op::Reshape>(pattern_map[reshape_label]);
const auto& input_order = matched_reshape->get_input_order(); const auto& input_order = matched_reshape->get_input_order();
auto hoisted_reshape_output_shape = apply_permutation<Shape::value_type>( auto hoisted_reshape_output_shape =
pattern_map[pad_input]->get_shape(), input_order); ngraph::apply_permutation<Shape>(pattern_map[pad_input]->get_shape(), input_order);
auto hoisted_reshape = std::make_shared<op::Reshape>( auto hoisted_reshape = std::make_shared<op::Reshape>(
pattern_map[pad_input], pattern_map[pad_input],
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/result_vector.hpp" #include "ngraph/op/result_vector.hpp"
#include "ngraph/runtime/backend.hpp" #include "ngraph/runtime/backend.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include <iostream> #include <iostream>
...@@ -451,6 +452,28 @@ void ngraph::check_fp_values_isnan(const char* name, const double* array, size_t ...@@ -451,6 +452,28 @@ void ngraph::check_fp_values_isnan(const char* name, const double* array, size_t
} }
} }
} }
template <typename T>
T ngraph::apply_permutation(T input, AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
T output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
template AxisVector ngraph::apply_permutation<AxisVector>(AxisVector input, AxisVector order);
template Shape ngraph::apply_permutation<Shape>(Shape input, AxisVector order);
AxisVector ngraph::get_default_order(const Shape& shape) AxisVector ngraph::get_default_order(const Shape& shape)
{ {
return get_default_order(shape.size()); return get_default_order(shape.size());
......
...@@ -222,6 +222,8 @@ namespace ngraph ...@@ -222,6 +222,8 @@ namespace ngraph
void* aligned_alloc(size_t alignment, size_t size); void* aligned_alloc(size_t alignment, size_t size);
void aligned_free(void*); void aligned_free(void*);
size_t round_up(size_t size, size_t alignment); size_t round_up(size_t size, size_t alignment);
template <typename T>
T apply_permutation(T input, ngraph::AxisVector order);
AxisVector get_default_order(size_t rank); AxisVector get_default_order(size_t rank);
AxisVector get_default_order(const Shape& shape); AxisVector get_default_order(const Shape& shape);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment