Commit 7851c349 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

ConstantFolding optimization (#4066)

* wip

* Use AlignedBuffer instead of vector for performance

* cleanup

* More cleanup

* Revert Constant change
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent d4a12feb
......@@ -331,7 +331,6 @@ namespace ngraph
}
bool is_constant() const override { return true; }
bool are_all_data_elements_bitwise_identical() const;
bool get_all_data_elements_bitwise_identical() const
{
return m_all_elements_bitwise_identical;
......@@ -435,6 +434,7 @@ namespace ngraph
Shape m_shape{};
std::unique_ptr<runtime::AlignedBuffer> m_data;
bool m_all_elements_bitwise_identical;
bool are_all_data_elements_bitwise_identical() const;
Constant(const Constant&) = delete;
Constant operator=(const Constant&) = delete;
};
......
......@@ -17,6 +17,7 @@
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/util.hpp"
namespace ngraph
......
......@@ -37,12 +37,14 @@ static shared_ptr<op::Constant>
fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant,
shared_ptr<Node> reduction_node)
{
vector<T> out_vec(shape_size(reduction_node->get_shape()));
const Shape& out_shape = reduction_node->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
if (auto max = as_type_ptr<op::Max>(reduction_node))
{
runtime::reference::max<T>(constant->get_vector<T>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
reduction_node->get_shape(),
max->get_reduction_axes());
......@@ -60,7 +62,7 @@ static shared_ptr<op::Constant>
}
}
runtime::reference::max<T>(constant->get_vector<T>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
shape_no_keep_dims,
reduce_max->get_reduction_axes());
......@@ -68,7 +70,7 @@ static shared_ptr<op::Constant>
else if (auto min = as_type_ptr<op::Min>(reduction_node))
{
runtime::reference::min<T>(constant->get_vector<T>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
reduction_node->get_shape(),
min->get_reduction_axes());
......@@ -86,7 +88,7 @@ static shared_ptr<op::Constant>
}
}
runtime::reference::min<T>(constant->get_vector<T>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
shape_no_keep_dims,
reduce_min->get_reduction_axes());
......@@ -94,7 +96,7 @@ static shared_ptr<op::Constant>
else if (auto prod = as_type_ptr<op::Product>(reduction_node))
{
runtime::reference::product<T>(constant->get_vector<T>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
reduction_node->get_shape(),
prod->get_reduction_axes());
......@@ -112,7 +114,7 @@ static shared_ptr<op::Constant>
}
}
runtime::reference::product<T>(constant->get_vector<T>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
shape_no_keep_dims,
reduce_prod->get_reduction_axes());
......@@ -120,7 +122,7 @@ static shared_ptr<op::Constant>
else if (auto sum = as_type_ptr<op::Sum>(reduction_node))
{
runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
reduction_node->get_shape(),
sum->get_reduction_axes());
......@@ -138,7 +140,7 @@ static shared_ptr<op::Constant>
}
}
runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
shape_no_keep_dims,
reduce_sum->get_reduction_axes());
......@@ -156,7 +158,7 @@ static shared_ptr<op::Constant>
}
}
runtime::reference::mean<T>(constant->get_vector<T>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
shape_no_keep_dims,
reduce_mean->get_reduction_axes());
......@@ -170,7 +172,7 @@ static shared_ptr<op::Constant>
}
return make_shared<op::Constant>(
reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec);
reduction_node->get_output_element_type(0), reduction_node->get_shape(), data_ptr);
}
static shared_ptr<op::Constant>
......
......@@ -56,89 +56,90 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
shared_ptr<Node> binary,
NodeExecutorTy func)
{
auto out_shape = binary->get_shape();
const Shape& out_shape = binary->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(char));
char* data_ptr = buffer.get_ptr<char>();
// NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape())
{
vector<char> out_vec(shape_size(out_shape));
vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
outputs.push_back(data_ptr);
func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(binary->get_output_element_type(0), out_shape, data_ptr);
}
else
{
if (auto and_v0_node = as_type_ptr<op::v0::And>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_and<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
and_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto logical_and_node = as_type_ptr<op::v1::LogicalAnd>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_and<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
logical_and_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto or_node = as_type_ptr<op::v0::Or>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_or<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto logical_or_node = as_type_ptr<op::v1::LogicalOr>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_or<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
logical_or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto xor_node = as_type_ptr<op::v0::Xor>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto logical_xor_node = as_type_ptr<op::v1::LogicalXor>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
logical_xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else
{
......@@ -155,155 +156,156 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
shared_ptr<Node> binary,
NodeExecutorTy func)
{
auto out_shape = binary->get_shape();
const Shape& out_shape = binary->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(char));
char* data_ptr = buffer.get_ptr<char>();
// NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape())
{
vector<char> out_vec(shape_size(out_shape));
vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
outputs.push_back(data_ptr);
func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(binary->get_output_element_type(0), out_shape, data_ptr);
}
else
{
if (auto equal_v0_node = as_type_ptr<op::v0::Equal>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
equal_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto equal_v1_node = as_type_ptr<op::v1::Equal>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
equal_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto greater_v0_node = as_type_ptr<op::v0::Greater>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
greater_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto greater_v1_node = as_type_ptr<op::v1::Greater>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
greater_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto greater_eq_v0_node = as_type_ptr<op::v0::GreaterEq>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
greater_eq_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto greater_eq_v1_node = as_type_ptr<op::v1::GreaterEqual>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
greater_eq_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto less_v0_node = as_type_ptr<op::v0::Less>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
less_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto less_v1_node = as_type_ptr<op::v1::Less>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
less_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto less_eq_v0_node = as_type_ptr<op::v0::LessEq>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
less_eq_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto less_eq_v1_node = as_type_ptr<op::v1::LessEqual>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
less_eq_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto not_equal_v0_node = as_type_ptr<op::v0::NotEqual>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
not_equal_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto not_equal_v1_node = as_type_ptr<op::v1::NotEqual>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
not_equal_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else
{
......@@ -319,21 +321,22 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
shared_ptr<Node> binary,
NodeExecutorTy func)
{
auto out_shape = binary->get_shape();
const Shape& out_shape = binary->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(Tout));
Tout* data_ptr = buffer.get_ptr<Tout>();
// NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape())
{
vector<Tout> out_vec(shape_size(out_shape));
vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
outputs.push_back(data_ptr);
func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(binary->get_output_element_type(0), out_shape, data_ptr);
}
else
{
......@@ -341,178 +344,178 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
add_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto add_v1_node = as_type_ptr<op::v1::Add>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
add_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto divide_v0_node = as_type_ptr<op::v0::Divide>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v0::Divide> divop = as_type_ptr<op::v0::Divide>(binary);
bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
divide_v0_node->get_autob(),
pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto divide_v1_node = as_type_ptr<op::v1::Divide>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v1::Divide> divop = as_type_ptr<op::v1::Divide>(binary);
bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
divide_v1_node->get_autob(),
pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto maximum_v0_node = as_type_ptr<op::v0::Maximum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
maximum_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto maximum_v1_node = as_type_ptr<op::v1::Maximum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
maximum_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto minimum_v0_node = as_type_ptr<op::v0::Minimum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
minimum_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto minimum_v1_node = as_type_ptr<op::v1::Minimum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
minimum_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto multiply_v0_node = as_type_ptr<op::v0::Multiply>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
multiply_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto multiply_v1_node = as_type_ptr<op::v1::Multiply>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
multiply_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto power_v0_node = as_type_ptr<op::v0::Power>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v0::Power> powop = as_type_ptr<op::v0::Power>(binary);
runtime::reference::power<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
power_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto power_v1_node = as_type_ptr<op::v1::Power>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v1::Power> powop = as_type_ptr<op::v1::Power>(binary);
runtime::reference::power<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
power_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else if (auto subtract_node = as_type_ptr<op::Subtract>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::subtract<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
data_ptr,
a->get_shape(),
b->get_shape(),
subtract_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
}
else
{
......
......@@ -27,15 +27,16 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
shared_ptr<Node> broadcast,
NodeExecutorTy func)
{
auto out_shape = broadcast->get_shape();
vector<T> out_vec(shape_size(out_shape));
const Shape& out_shape = broadcast->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
if (func != nullptr)
if (func)
{
vector<void*> inputs;
inputs.push_back(const_cast<void*>(constant->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
outputs.push_back(data_ptr);
func(inputs, outputs);
}
......@@ -45,7 +46,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
if (static_bcast_axes.first)
{
runtime::reference::broadcast<T>(constant->get_data_ptr<T>(),
out_vec.data(),
data_ptr,
constant->get_shape(),
out_shape,
static_bcast_axes.second);
......@@ -58,7 +59,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
else if (auto broadcast_v0 = as_type_ptr<op::v0::Broadcast>(broadcast))
{
runtime::reference::broadcast<T>(constant->get_data_ptr<T>(),
out_vec.data(),
data_ptr,
constant->get_shape(),
out_shape,
broadcast_v0->get_broadcast_axes());
......@@ -68,7 +69,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
throw ngraph_error("Unsupported op in broadcast constant folding.");
}
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
}
void pass::ConstantFolding::construct_constant_broadcast()
......
......@@ -36,16 +36,14 @@ static shared_ptr<op::Constant> fold_constant_concat_helper(const shared_ptr<op:
arg_shapes.push_back(input.get_shape());
}
std::vector<T> result_vec(shape_size(concat->get_shape()));
runtime::AlignedBuffer buffer(shape_size(concat->get_shape()) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
runtime::reference::concat<T>(arg_bufs,
result_vec.data(),
arg_shapes,
concat->get_shape(),
concat->get_concatenation_axis());
runtime::reference::concat<T>(
arg_bufs, data_ptr, arg_shapes, concat->get_shape(), concat->get_concatenation_axis());
return make_shared<op::Constant>(
concat->get_output_element_type(0), concat->get_output_shape(0), result_vec);
concat->get_output_element_type(0), concat->get_output_shape(0), data_ptr);
}
void pass::ConstantFolding::construct_constant_concat()
......
......@@ -28,13 +28,14 @@ template <typename TI, typename TO>
shared_ptr<op::Constant> fold_constant_convert_helper1(shared_ptr<op::Constant> constant,
const element::Type& output_element_type)
{
auto out_shape = constant->get_shape();
vector<TO> out_vec(shape_size(out_shape));
const Shape& out_shape = constant->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(TO));
TO* data_ptr = buffer.get_ptr<TO>();
runtime::reference::convert<TI, TO>(
constant->get_vector<TI>().data(), out_vec.data(), shape_size(out_shape));
constant->get_vector<TI>().data(), data_ptr, shape_size(out_shape));
return make_shared<op::Constant>(output_element_type, out_shape, out_vec);
return make_shared<op::Constant>(output_element_type, out_shape, data_ptr);
}
// Helper for mapping element::Types to runtime::reference::convert, which is templated in C++
......
......@@ -27,18 +27,19 @@ shared_ptr<op::Constant> fold_constant_dequantize(shared_ptr<op::Constant> const
shared_ptr<op::Constant> scale,
shared_ptr<op::Constant> offset)
{
auto out_shape = constant->get_shape();
vector<REAL> out_vec(shape_size(out_shape));
const Shape& out_shape = constant->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(REAL));
REAL* data_ptr = buffer.get_ptr<REAL>();
runtime::reference::dequantize<QUANT, REAL>(constant->get_vector<QUANT>().data(),
scale->get_vector<REAL>().data(),
offset->get_vector<QUANT>().data(),
out_vec.data(),
data_ptr,
constant->get_shape(),
scale->get_shape(),
dequant->get_axes());
return make_shared<op::Constant>(dequant->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(dequant->get_element_type(), out_shape, data_ptr);
}
void pass::ConstantFolding::construct_constant_dequantize()
......
......@@ -27,16 +27,14 @@ shared_ptr<op::Constant> fold_constant_dyn_broadcast(shared_ptr<op::Constant> ar
shared_ptr<op::Constant> shape,
shared_ptr<op::Constant> axes)
{
auto out_shape = shape->get_shape_val();
vector<T> out_vec(shape_size(out_shape));
const Shape& out_shape = shape->get_shape_val();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
runtime::reference::broadcast<T>(arg->get_data_ptr<T>(),
out_vec.data(),
arg->get_shape(),
out_shape,
axes->get_axis_set_val());
runtime::reference::broadcast<T>(
arg->get_data_ptr<T>(), data_ptr, arg->get_shape(), out_shape, axes->get_axis_set_val());
return make_shared<op::Constant>(arg->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(arg->get_element_type(), out_shape, data_ptr);
}
void pass::ConstantFolding::construct_constant_dyn_broadcast()
......
......@@ -28,20 +28,20 @@ template <class T>
shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> constant_data,
shared_ptr<op::v1::Reshape> dyn_reshape)
{
auto out_shape = dyn_reshape->get_shape();
const Shape& out_shape = dyn_reshape->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
AxisVector input_order(constant_data->get_shape().size());
std::iota(input_order.begin(), input_order.end(), 0);
vector<T> out_vec(shape_size(out_shape));
runtime::reference::reshape<T>(constant_data->get_data_ptr<T>(),
out_vec.data(),
data_ptr,
constant_data->get_shape(),
input_order,
out_shape);
return make_shared<op::Constant>(dyn_reshape->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(dyn_reshape->get_element_type(), out_shape, data_ptr);
}
void pass::ConstantFolding::construct_constant_dyn_reshape()
......
......@@ -42,31 +42,34 @@ shared_ptr<op::Constant> fold_constant_dyn_slice(shared_ptr<op::Constant> data,
slice->get_shrink_axis(),
slice->get_ellipsis_mask());
vector<T> slice_out_vec(shape_size(plan.reshape_in_shape));
runtime::AlignedBuffer slice_out_buffer(shape_size(plan.reshape_in_shape) * sizeof(T));
T* slice_out_data = slice_out_buffer.get_ptr<T>();
runtime::reference::slice<T>(data->get_data_ptr<T>(),
slice_out_vec.data(),
slice_out_data,
data->get_shape(),
Coordinate(plan.begins.begin(), plan.begins.end()),
Coordinate(plan.ends.begin(), plan.ends.end()),
Strides(plan.strides.begin(), plan.strides.end()),
plan.reshape_in_shape);
vector<T> reshape_out_vec(shape_size(plan.reshape_out_shape));
runtime::reference::reshape<T>(slice_out_vec.data(),
reshape_out_vec.data(),
runtime::AlignedBuffer reshape_out_buffer(shape_size(plan.reshape_out_shape) * sizeof(T));
T* reshape_out_data = reshape_out_buffer.get_ptr<T>();
runtime::reference::reshape<T>(slice_out_data,
reshape_out_data,
plan.reshape_in_shape,
get_default_order(plan.reshape_in_shape.size()),
plan.reshape_out_shape);
vector<T> reverse_out_vec(shape_size(plan.reshape_out_shape));
runtime::reference::reverse<T>(reshape_out_vec.data(),
reverse_out_vec.data(),
runtime::AlignedBuffer reverse_out_buffer(shape_size(plan.reshape_out_shape) * sizeof(T));
T* reverse_out_data = reverse_out_buffer.get_ptr<T>();
runtime::reference::reverse<T>(reshape_out_data,
reverse_out_data,
plan.reshape_out_shape,
plan.reshape_out_shape,
plan.reverse_axes);
return make_shared<op::Constant>(
data->get_element_type(), plan.reshape_out_shape, reverse_out_vec);
data->get_element_type(), plan.reshape_out_shape, reverse_out_data);
}
void pass::ConstantFolding::construct_constant_dyn_slice()
......
......@@ -28,13 +28,14 @@ static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op:
const shared_ptr<op::Constant>& indices,
const shared_ptr<Node>& gather)
{
std::vector<T> result_vec(shape_size(gather->get_shape()));
runtime::AlignedBuffer buffer(shape_size(gather->get_shape()) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
if (auto gather_v1 = as_type_ptr<op::v1::Gather>(gather))
{
runtime::reference::gather<T, U>(data->get_data_ptr<T>(),
indices->get_data_ptr<U>(),
result_vec.data(),
data_ptr,
data->get_shape(),
indices->get_shape(),
gather_v1->get_shape(),
......@@ -44,7 +45,7 @@ static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op:
{
runtime::reference::gather<T, U>(data->get_data_ptr<T>(),
indices->get_data_ptr<U>(),
result_vec.data(),
data_ptr,
data->get_shape(),
indices->get_shape(),
gather_v0->get_shape(),
......@@ -56,7 +57,7 @@ static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op:
}
return make_shared<op::Constant>(
gather->get_output_element_type(0), gather->get_output_shape(0), result_vec);
gather->get_output_element_type(0), gather->get_output_shape(0), data_ptr);
}
template <typename T>
......
......@@ -43,12 +43,13 @@ static Shape get_shape_no_keep_dims(const AxisSet& reduction_axes, const Shape&
static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::Constant> constant,
shared_ptr<Node> reduction_node)
{
vector<char> out_vec(shape_size(reduction_node->get_shape()));
runtime::AlignedBuffer buffer(shape_size(reduction_node->get_shape()) * sizeof(char));
char* data_ptr = buffer.get_ptr<char>();
if (auto all = as_type_ptr<::ngraph::op::All>(reduction_node))
{
runtime::reference::all(constant->get_vector<char>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
reduction_node->get_shape(),
all->get_reduction_axes());
......@@ -56,7 +57,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
else if (auto any = as_type_ptr<::ngraph::op::Any>(reduction_node))
{
runtime::reference::any(constant->get_vector<char>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
reduction_node->get_shape(),
any->get_reduction_axes());
......@@ -67,7 +68,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
const auto input_shape = reduce_and->get_input_shape(0);
runtime::reference::all(constant->get_vector<char>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
get_shape_no_keep_dims(reduction_axes, input_shape),
reduction_axes);
......@@ -78,7 +79,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
const auto input_shape = reduce_or->get_input_shape(0);
runtime::reference::any(constant->get_vector<char>().data(),
out_vec.data(),
data_ptr,
constant->get_output_shape(0),
get_shape_no_keep_dims(reduction_axes, input_shape),
reduction_axes);
......@@ -92,7 +93,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
}
return make_shared<op::Constant>(
reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec);
reduction_node->get_output_element_type(0), reduction_node->get_shape(), data_ptr);
}
void pass::ConstantFolding::construct_constant_logical_reduction()
......
......@@ -26,8 +26,9 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
shared_ptr<op::Pad> pad,
NodeExecutorTy func)
{
auto out_shape = pad->get_shape();
vector<T> out_vec(shape_size(out_shape));
const Shape& out_shape = pad->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
auto pad_value = std::static_pointer_cast<op::Constant>(
pad->input(1).get_source_output().get_node_shared_ptr());
......@@ -38,7 +39,7 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
inputs.push_back(const_cast<void*>(pad_value->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
outputs.push_back(data_ptr);
func(inputs, outputs);
}
......@@ -46,7 +47,7 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
{
runtime::reference::pad<T>(constant->get_data_ptr<T>(),
pad_value->get_data_ptr<T>(),
out_vec.data(),
data_ptr,
constant->get_shape(),
out_shape,
pad->get_padding_below(),
......@@ -54,7 +55,7 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
pad->get_pad_mode());
}
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
}
void pass::ConstantFolding::construct_constant_pad()
......
......@@ -27,19 +27,20 @@ shared_ptr<op::Constant> fold_constant_quantize(shared_ptr<op::Constant> constan
shared_ptr<op::Constant> scale,
shared_ptr<op::Constant> offset)
{
auto out_shape = constant->get_shape();
vector<QUANT> out_vec(shape_size(out_shape));
const Shape& out_shape = constant->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(QUANT));
QUANT* data_ptr = buffer.get_ptr<QUANT>();
runtime::reference::quantize<REAL, QUANT>(constant->get_vector<REAL>().data(),
scale->get_vector<REAL>().data(),
offset->get_vector<QUANT>().data(),
out_vec.data(),
data_ptr,
constant->get_shape(),
scale->get_shape(),
quant->get_axes(),
quant->get_round_mode());
return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(quant->get_element_type(), out_shape, data_ptr);
}
void pass::ConstantFolding::construct_constant_quantize()
......
......@@ -26,13 +26,12 @@ shared_ptr<op::Constant> fold_constant_range(shared_ptr<op::Constant> start,
shared_ptr<op::Constant> step,
shared_ptr<op::Range> range)
{
vector<T> out_vec(shape_size(range->get_shape()));
runtime::reference::range<T>(start->get_vector<T>().data(),
step->get_vector<T>().data(),
range->get_shape(),
out_vec.data());
runtime::AlignedBuffer buffer(shape_size(range->get_shape()) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
runtime::reference::range<T>(
start->get_vector<T>().data(), step->get_vector<T>().data(), range->get_shape(), data_ptr);
return make_shared<op::Constant>(range->get_element_type(), range->get_shape(), out_vec);
return make_shared<op::Constant>(range->get_element_type(), range->get_shape(), data_ptr);
}
void pass::ConstantFolding::construct_constant_range()
......
......@@ -26,28 +26,29 @@ shared_ptr<op::Constant> fold_constant_reshape(shared_ptr<op::Constant> constant
shared_ptr<op::Reshape> reshape,
NodeExecutorTy func)
{
auto out_shape = reshape->get_shape();
vector<T> out_vec(shape_size(out_shape));
const Shape& out_shape = reshape->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
if (func != nullptr)
{
vector<void*> inputs;
inputs.push_back(const_cast<void*>(constant->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
outputs.push_back(data_ptr);
func(inputs, outputs);
}
else
{
runtime::reference::reshape<T>(constant->get_data_ptr<T>(),
out_vec.data(),
data_ptr,
constant->get_shape(),
reshape->get_input_order(),
out_shape);
}
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
}
void pass::ConstantFolding::construct_constant_reshape()
......
......@@ -25,13 +25,14 @@ template <typename T>
static shared_ptr<op::Constant> fold_constant_reverse_helper(shared_ptr<op::Constant> constant,
const AxisSet& reversed_axes)
{
auto out_shape = constant->get_shape();
vector<T> out_vec(shape_size(out_shape));
const Shape& out_shape = constant->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
runtime::reference::reverse<T>(
constant->get_vector<T>().data(), out_vec.data(), out_shape, out_shape, reversed_axes);
constant->get_vector<T>().data(), data_ptr, out_shape, out_shape, reversed_axes);
return make_shared<op::Constant>(constant->get_output_element_type(0), out_shape, out_vec);
return make_shared<op::Constant>(constant->get_output_element_type(0), out_shape, data_ptr);
}
static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> constant,
......
......@@ -27,15 +27,16 @@ shared_ptr<op::Constant> fold_constant_select(const shared_ptr<op::Constant>& se
const shared_ptr<op::Constant>& f,
const shared_ptr<Node>& select)
{
auto out_shape = select->get_shape();
vector<T> out_vec(shape_size(out_shape));
const Shape& out_shape = select->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
if (auto select_v0 = as_type_ptr<op::v0::Select>(select))
{
runtime::reference::select<T>(selection->get_data_ptr<char>(),
t->get_data_ptr<T>(),
f->get_data_ptr<T>(),
out_vec.data(),
data_ptr,
shape_size(out_shape));
}
else if (auto select_v1 = as_type_ptr<op::v1::Select>(select))
......@@ -43,14 +44,14 @@ shared_ptr<op::Constant> fold_constant_select(const shared_ptr<op::Constant>& se
runtime::reference::select<T>(selection->get_data_ptr<char>(),
t->get_data_ptr<T>(),
f->get_data_ptr<T>(),
out_vec.data(),
data_ptr,
selection->get_shape(),
t->get_shape(),
f->get_shape(),
select_v1->get_auto_broadcast());
}
return make_shared<op::Constant>(select->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(select->get_element_type(), out_shape, data_ptr);
}
void pass::ConstantFolding::construct_constant_select()
......
......@@ -25,18 +25,19 @@ template <class T>
shared_ptr<op::Constant> fold_constant_slice(shared_ptr<op::Constant> constant,
shared_ptr<op::Slice> slice)
{
auto out_shape = slice->get_shape();
vector<T> out_vec(shape_size(out_shape));
const Shape& out_shape = slice->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
runtime::reference::slice<T>(constant->get_data_ptr<T>(),
out_vec.data(),
data_ptr,
constant->get_shape(),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
slice->get_strides(),
out_shape);
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
}
void pass::ConstantFolding::construct_constant_slice()
......
......@@ -24,10 +24,9 @@ template <class T>
shared_ptr<op::Constant> fold_constant_squeeze(shared_ptr<op::Constant> constant,
shared_ptr<op::Squeeze> squeeze)
{
auto out_shape = squeeze->get_shape();
vector<T> out_vec(shape_size(out_shape));
out_vec = constant->get_vector<T>();
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
const Shape& out_shape = squeeze->get_shape();
return make_shared<op::Constant>(
constant->get_element_type(), out_shape, constant->get_data_ptr());
}
void pass::ConstantFolding::construct_constant_squeeze()
......
......@@ -54,31 +54,31 @@ shared_ptr<op::Constant> fold_constant_strided_slice(shared_ptr<op::Constant> da
convert_mask_to_axis_set(slice->get_shrink_axis_mask()),
convert_mask_to_axis_set(slice->get_ellipsis_mask()));
vector<T> slice_out_vec(shape_size(plan.reshape_in_shape));
runtime::AlignedBuffer slice_out_buffer(shape_size(plan.reshape_in_shape) * sizeof(T));
runtime::reference::slice<T>(data->get_data_ptr<T>(),
slice_out_vec.data(),
slice_out_buffer.get_ptr<T>(),
data->get_shape(),
Coordinate(plan.begins.begin(), plan.begins.end()),
Coordinate(plan.ends.begin(), plan.ends.end()),
Strides(plan.strides.begin(), plan.strides.end()),
plan.reshape_in_shape);
vector<T> reshape_out_vec(shape_size(plan.reshape_out_shape));
runtime::reference::reshape<T>(slice_out_vec.data(),
reshape_out_vec.data(),
runtime::AlignedBuffer reshape_out_buffer(shape_size(plan.reshape_out_shape) * sizeof(T));
runtime::reference::reshape<T>(slice_out_buffer.get_ptr<T>(),
reshape_out_buffer.get_ptr<T>(),
plan.reshape_in_shape,
get_default_order(plan.reshape_in_shape.size()),
plan.reshape_out_shape);
vector<T> reverse_out_vec(shape_size(plan.reshape_out_shape));
runtime::reference::reverse<T>(reshape_out_vec.data(),
reverse_out_vec.data(),
runtime::AlignedBuffer reverse_out_buffer(shape_size(plan.reshape_out_shape) * sizeof(T));
runtime::reference::reverse<T>(reshape_out_buffer.get_ptr<T>(),
reverse_out_buffer.get_ptr<T>(),
plan.reshape_out_shape,
plan.reshape_out_shape,
plan.reverse_axes);
return make_shared<op::Constant>(
data->get_element_type(), plan.reshape_out_shape, reverse_out_vec);
data->get_element_type(), plan.reshape_out_shape, reverse_out_buffer.get_ptr<T>());
}
void pass::ConstantFolding::construct_constant_strided_slice()
......
......@@ -26,18 +26,18 @@ shared_ptr<op::Constant> fold_constant_transpose(shared_ptr<op::Constant> consta
shared_ptr<op::Constant> constant_perm,
shared_ptr<op::Transpose> transpose)
{
auto out_shape = transpose->get_shape();
const Shape& out_shape = transpose->get_shape();
auto input_order = constant_perm->get_axis_vector_val();
vector<T> out_vec(shape_size(out_shape));
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
runtime::reference::reshape<T>(constant_data->get_data_ptr<T>(),
out_vec.data(),
buffer.get_ptr<T>(),
constant_data->get_shape(),
input_order,
out_shape);
return make_shared<op::Constant>(transpose->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(transpose->get_element_type(), out_shape, buffer.get_ptr<T>());
}
void pass::ConstantFolding::construct_constant_transpose()
......
......@@ -60,15 +60,15 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
}
}
auto out_shape = unary->get_shape();
vector<T> out_vec(shape_size(out_shape));
const Shape& out_shape = unary->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
if (func != nullptr)
{
vector<void*> inputs;
inputs.push_back(const_cast<void*>(constant->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
outputs.push_back(buffer.get_ptr<T>());
func(inputs, outputs);
}
......@@ -77,47 +77,47 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
if (is_type<op::Abs>(unary))
{
runtime::reference::abs<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
}
else if (is_type<op::Ceiling>(unary))
{
runtime::reference::ceiling<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
}
else if (is_type<op::Floor>(unary))
{
runtime::reference::floor<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
}
else if (is_type<op::v1::LogicalNot>(unary))
{
runtime::reference::logical_not<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
}
else if (is_type<op::Negative>(unary))
{
runtime::reference::negate<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
}
else if (is_type<op::v0::Not>(unary))
{
runtime::reference::logical_not<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
}
else if (is_type<op::Relu>(unary))
{
runtime::reference::relu<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
}
else if (is_type<op::Sign>(unary))
{
runtime::reference::sign<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
}
else if (is_type<op::Sqrt>(unary))
{
runtime::reference::sqrt<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
}
else
{
......@@ -125,7 +125,7 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
}
}
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
return make_shared<op::Constant>(constant->get_element_type(), out_shape, buffer.get_ptr<T>());
}
void pass::ConstantFolding::construct_constant_unary()
......
......@@ -24,10 +24,9 @@ template <class T>
shared_ptr<op::Constant> fold_constant_unsqueeze(shared_ptr<op::Constant> constant,
shared_ptr<op::Unsqueeze> unsqueeze)
{
auto out_shape = unsqueeze->get_shape();
vector<T> out_vec(shape_size(out_shape));
out_vec = constant->get_vector<T>();
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
const Shape& out_shape = unsqueeze->get_shape();
return make_shared<op::Constant>(
constant->get_element_type(), out_shape, constant->get_data_ptr());
}
void pass::ConstantFolding::construct_constant_unsqueeze()
......
......@@ -32,7 +32,7 @@ bool pass::ConstantToBroadcast::run_on_node(shared_ptr<Node> node)
size_t size = shape_size(constant->get_shape());
if (size > minimum_size_of_interest)
{
if (constant->are_all_data_elements_bitwise_identical())
if (constant->get_all_data_elements_bitwise_identical())
{
auto scalar_constant = make_shared<op::Constant>(
constant->get_element_type(), Shape{}, constant->get_data_ptr());
......
......@@ -37,7 +37,7 @@ public:
// Allocator objects and the allocation interfaces are owned by the
// creators of AlignedBuffers. They need to ensure that the lifetime of
// allocator exceeds the lifetime of this AlignedBuffer.
AlignedBuffer(size_t byte_size, size_t alignment, Allocator* allocator = nullptr);
AlignedBuffer(size_t byte_size, size_t alignment = 64, Allocator* allocator = nullptr);
AlignedBuffer();
~AlignedBuffer();
......@@ -47,7 +47,25 @@ public:
size_t size() const { return m_byte_size; }
void* get_ptr(size_t offset) const { return m_aligned_buffer + offset; }
void* get_ptr() const { return m_aligned_buffer; }
void* get_ptr() { return m_aligned_buffer; }
const void* get_ptr() const { return m_aligned_buffer; }
template <typename T>
T* get_ptr()
{
return reinterpret_cast<T*>(m_aligned_buffer);
}
template <typename T>
const T* get_ptr() const
{
return reinterpret_cast<const T*>(m_aligned_buffer);
}
template <typename T>
explicit operator T*()
{
return get_ptr<T>();
}
private:
AlignedBuffer(const AlignedBuffer&) = delete;
AlignedBuffer& operator=(const AlignedBuffer&) = delete;
......
......@@ -3338,7 +3338,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Constant:
{
auto tmp = static_cast<const op::Constant*>(&n);
if (tmp->are_all_data_elements_bitwise_identical() && shape_size(tmp->get_shape()) > 0)
if (tmp->get_all_data_elements_bitwise_identical() && shape_size(tmp->get_shape()) > 0)
{
vector<string> vs;
vs.push_back(tmp->convert_value_to_string(0));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment