Commit 04230095 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

apply perm (#1460)

parent abff494d
......@@ -33,23 +33,9 @@
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/util.hpp"
template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
std::vector<T> output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
extern template ngraph::AxisVector
ngraph::apply_permutation<ngraph::AxisVector>(ngraph::AxisVector input,
ngraph::AxisVector order);
void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
{
......@@ -133,8 +119,8 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
return true;
}
auto perm1 = apply_permutation(do_r1, r1->get_input_order());
auto perm2 = apply_permutation(perm1, r2->get_input_order());
auto perm1 = ngraph::apply_permutation(do_r1, r1->get_input_order());
auto perm2 = ngraph::apply_permutation(perm1, r2->get_input_order());
if (perm2 == do_r1)
{
NGRAPH_DEBUG << "Two transposes were removed!";
......
......@@ -56,6 +56,9 @@
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/util.hpp"
extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input,
ngraph::AxisVector order);
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
std::shared_ptr<ngraph::Node> arg,
bool& transpose_w,
......@@ -108,24 +111,6 @@ static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
return true;
}
template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
std::vector<T> output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
{
Shape shape_w{2, 4};
......@@ -427,8 +412,8 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
std::dynamic_pointer_cast<op::Reshape>(pattern_map[reshape_label]);
const auto& input_order = matched_reshape->get_input_order();
auto hoisted_reshape_output_shape = apply_permutation<Shape::value_type>(
pattern_map[pad_input]->get_shape(), input_order);
auto hoisted_reshape_output_shape =
ngraph::apply_permutation<Shape>(pattern_map[pad_input]->get_shape(), input_order);
auto hoisted_reshape = std::make_shared<op::Reshape>(
pattern_map[pad_input],
......
......@@ -29,6 +29,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/result_vector.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
#include <iostream>
......@@ -451,6 +452,28 @@ void ngraph::check_fp_values_isnan(const char* name, const double* array, size_t
}
}
}
template <typename T>
T ngraph::apply_permutation(T input, AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
T output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
template AxisVector ngraph::apply_permutation<AxisVector>(AxisVector input, AxisVector order);
template Shape ngraph::apply_permutation<Shape>(Shape input, AxisVector order);
AxisVector ngraph::get_default_order(const Shape& shape)
{
return get_default_order(shape.size());
......
......@@ -222,6 +222,8 @@ namespace ngraph
void* aligned_alloc(size_t alignment, size_t size);
void aligned_free(void*);
size_t round_up(size_t size, size_t alignment);
template <typename T>
T apply_permutation(T input, ngraph::AxisVector order);
AxisVector get_default_order(size_t rank);
AxisVector get_default_order(const Shape& shape);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment