Commit 2f37d151 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Faster constant folding for v1::Reshape, v1::Transpose and v0::DynReshape (#4214)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 44b58722
......@@ -29,20 +29,10 @@ template <typename T, typename R>
shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> constant_data,
R dyn_reshape)
{
const Shape& out_shape = dyn_reshape->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
AxisVector input_order(constant_data->get_shape().size());
std::iota(input_order.begin(), input_order.end(), 0);
runtime::reference::reshape<T>(constant_data->get_data_ptr<T>(),
data_ptr,
constant_data->get_shape(),
input_order,
out_shape);
return make_shared<op::Constant>(dyn_reshape->get_element_type(), out_shape, data_ptr);
// v1::Reshape and v0::DynReshape do not allow data transposes.
return make_shared<op::Constant>(dyn_reshape->get_element_type(),
dyn_reshape->get_shape(),
constant_data->get_data_ptr<T>());
}
template <typename R>
......
......@@ -16,7 +16,7 @@
#include "constant_folding.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/opt_kernel/reshape.hpp"
using namespace std;
using namespace ngraph;
......@@ -31,7 +31,7 @@ shared_ptr<op::Constant> fold_constant_transpose(shared_ptr<op::Constant> consta
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
runtime::reference::reshape<T>(constant_data->get_data_ptr<T>(),
runtime::opt_kernel::reshape<T>(constant_data->get_data_ptr<T>(),
buffer.get_ptr<T>(),
constant_data->get_shape(),
input_order,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment