Commit 7851c349 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

ConstantFolding optimization (#4066)

* wip

* Use AlignedBuffer instead of vector for performance

* cleanup

* More cleanup

* Revert Constant change
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent d4a12feb
...@@ -331,7 +331,6 @@ namespace ngraph ...@@ -331,7 +331,6 @@ namespace ngraph
} }
bool is_constant() const override { return true; } bool is_constant() const override { return true; }
bool are_all_data_elements_bitwise_identical() const;
bool get_all_data_elements_bitwise_identical() const bool get_all_data_elements_bitwise_identical() const
{ {
return m_all_elements_bitwise_identical; return m_all_elements_bitwise_identical;
...@@ -435,6 +434,7 @@ namespace ngraph ...@@ -435,6 +434,7 @@ namespace ngraph
Shape m_shape{}; Shape m_shape{};
std::unique_ptr<runtime::AlignedBuffer> m_data; std::unique_ptr<runtime::AlignedBuffer> m_data;
bool m_all_elements_bitwise_identical; bool m_all_elements_bitwise_identical;
bool are_all_data_elements_bitwise_identical() const;
Constant(const Constant&) = delete; Constant(const Constant&) = delete;
Constant operator=(const Constant&) = delete; Constant operator=(const Constant&) = delete;
}; };
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include "ngraph/pass/graph_rewrite.hpp" #include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
namespace ngraph namespace ngraph
......
...@@ -37,12 +37,14 @@ static shared_ptr<op::Constant> ...@@ -37,12 +37,14 @@ static shared_ptr<op::Constant>
fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant, fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant,
shared_ptr<Node> reduction_node) shared_ptr<Node> reduction_node)
{ {
vector<T> out_vec(shape_size(reduction_node->get_shape())); const Shape& out_shape = reduction_node->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
if (auto max = as_type_ptr<op::Max>(reduction_node)) if (auto max = as_type_ptr<op::Max>(reduction_node))
{ {
runtime::reference::max<T>(constant->get_vector<T>().data(), runtime::reference::max<T>(constant->get_vector<T>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
reduction_node->get_shape(), reduction_node->get_shape(),
max->get_reduction_axes()); max->get_reduction_axes());
...@@ -60,7 +62,7 @@ static shared_ptr<op::Constant> ...@@ -60,7 +62,7 @@ static shared_ptr<op::Constant>
} }
} }
runtime::reference::max<T>(constant->get_vector<T>().data(), runtime::reference::max<T>(constant->get_vector<T>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
shape_no_keep_dims, shape_no_keep_dims,
reduce_max->get_reduction_axes()); reduce_max->get_reduction_axes());
...@@ -68,7 +70,7 @@ static shared_ptr<op::Constant> ...@@ -68,7 +70,7 @@ static shared_ptr<op::Constant>
else if (auto min = as_type_ptr<op::Min>(reduction_node)) else if (auto min = as_type_ptr<op::Min>(reduction_node))
{ {
runtime::reference::min<T>(constant->get_vector<T>().data(), runtime::reference::min<T>(constant->get_vector<T>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
reduction_node->get_shape(), reduction_node->get_shape(),
min->get_reduction_axes()); min->get_reduction_axes());
...@@ -86,7 +88,7 @@ static shared_ptr<op::Constant> ...@@ -86,7 +88,7 @@ static shared_ptr<op::Constant>
} }
} }
runtime::reference::min<T>(constant->get_vector<T>().data(), runtime::reference::min<T>(constant->get_vector<T>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
shape_no_keep_dims, shape_no_keep_dims,
reduce_min->get_reduction_axes()); reduce_min->get_reduction_axes());
...@@ -94,7 +96,7 @@ static shared_ptr<op::Constant> ...@@ -94,7 +96,7 @@ static shared_ptr<op::Constant>
else if (auto prod = as_type_ptr<op::Product>(reduction_node)) else if (auto prod = as_type_ptr<op::Product>(reduction_node))
{ {
runtime::reference::product<T>(constant->get_vector<T>().data(), runtime::reference::product<T>(constant->get_vector<T>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
reduction_node->get_shape(), reduction_node->get_shape(),
prod->get_reduction_axes()); prod->get_reduction_axes());
...@@ -112,7 +114,7 @@ static shared_ptr<op::Constant> ...@@ -112,7 +114,7 @@ static shared_ptr<op::Constant>
} }
} }
runtime::reference::product<T>(constant->get_vector<T>().data(), runtime::reference::product<T>(constant->get_vector<T>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
shape_no_keep_dims, shape_no_keep_dims,
reduce_prod->get_reduction_axes()); reduce_prod->get_reduction_axes());
...@@ -120,7 +122,7 @@ static shared_ptr<op::Constant> ...@@ -120,7 +122,7 @@ static shared_ptr<op::Constant>
else if (auto sum = as_type_ptr<op::Sum>(reduction_node)) else if (auto sum = as_type_ptr<op::Sum>(reduction_node))
{ {
runtime::reference::sum<T>(constant->get_vector<T>().data(), runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
reduction_node->get_shape(), reduction_node->get_shape(),
sum->get_reduction_axes()); sum->get_reduction_axes());
...@@ -138,7 +140,7 @@ static shared_ptr<op::Constant> ...@@ -138,7 +140,7 @@ static shared_ptr<op::Constant>
} }
} }
runtime::reference::sum<T>(constant->get_vector<T>().data(), runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
shape_no_keep_dims, shape_no_keep_dims,
reduce_sum->get_reduction_axes()); reduce_sum->get_reduction_axes());
...@@ -156,7 +158,7 @@ static shared_ptr<op::Constant> ...@@ -156,7 +158,7 @@ static shared_ptr<op::Constant>
} }
} }
runtime::reference::mean<T>(constant->get_vector<T>().data(), runtime::reference::mean<T>(constant->get_vector<T>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
shape_no_keep_dims, shape_no_keep_dims,
reduce_mean->get_reduction_axes()); reduce_mean->get_reduction_axes());
...@@ -170,7 +172,7 @@ static shared_ptr<op::Constant> ...@@ -170,7 +172,7 @@ static shared_ptr<op::Constant>
} }
return make_shared<op::Constant>( return make_shared<op::Constant>(
reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec); reduction_node->get_output_element_type(0), reduction_node->get_shape(), data_ptr);
} }
static shared_ptr<op::Constant> static shared_ptr<op::Constant>
......
...@@ -56,89 +56,90 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons ...@@ -56,89 +56,90 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
shared_ptr<Node> binary, shared_ptr<Node> binary,
NodeExecutorTy func) NodeExecutorTy func)
{ {
auto out_shape = binary->get_shape(); const Shape& out_shape = binary->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(char));
char* data_ptr = buffer.get_ptr<char>();
// NOTE: We will skip the executor if the shapes do not match, because that means // NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that. // auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape()) if (func != nullptr && a->get_shape() == b->get_shape())
{ {
vector<char> out_vec(shape_size(out_shape));
vector<void*> inputs; vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr())); inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr())); inputs.push_back(const_cast<void*>(b->get_data_ptr()));
vector<void*> outputs; vector<void*> outputs;
outputs.push_back(out_vec.data()); outputs.push_back(data_ptr);
func(inputs, outputs); func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_output_element_type(0), out_shape, data_ptr);
} }
else else
{ {
if (auto and_v0_node = as_type_ptr<op::v0::And>(binary)) if (auto and_v0_node = as_type_ptr<op::v0::And>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_and<char>(a->get_data_ptr<char>(), runtime::reference::logical_and<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(), b->get_data_ptr<char>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
and_v0_node->get_autob()); and_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto logical_and_node = as_type_ptr<op::v1::LogicalAnd>(binary)) else if (auto logical_and_node = as_type_ptr<op::v1::LogicalAnd>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_and<char>(a->get_data_ptr<char>(), runtime::reference::logical_and<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(), b->get_data_ptr<char>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
logical_and_node->get_autob()); logical_and_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto or_node = as_type_ptr<op::v0::Or>(binary)) else if (auto or_node = as_type_ptr<op::v0::Or>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_or<char>(a->get_data_ptr<char>(), runtime::reference::logical_or<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(), b->get_data_ptr<char>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
or_node->get_autob()); or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto logical_or_node = as_type_ptr<op::v1::LogicalOr>(binary)) else if (auto logical_or_node = as_type_ptr<op::v1::LogicalOr>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_or<char>(a->get_data_ptr<char>(), runtime::reference::logical_or<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(), b->get_data_ptr<char>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
logical_or_node->get_autob()); logical_or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto xor_node = as_type_ptr<op::v0::Xor>(binary)) else if (auto xor_node = as_type_ptr<op::v0::Xor>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(), runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(), b->get_data_ptr<char>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
xor_node->get_autob()); xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto logical_xor_node = as_type_ptr<op::v1::LogicalXor>(binary)) else if (auto logical_xor_node = as_type_ptr<op::v1::LogicalXor>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(), runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(), b->get_data_ptr<char>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
logical_xor_node->get_autob()); logical_xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else else
{ {
...@@ -155,155 +156,156 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant ...@@ -155,155 +156,156 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
shared_ptr<Node> binary, shared_ptr<Node> binary,
NodeExecutorTy func) NodeExecutorTy func)
{ {
auto out_shape = binary->get_shape(); const Shape& out_shape = binary->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(char));
char* data_ptr = buffer.get_ptr<char>();
// NOTE: We will skip the executor if the shapes do not match, because that means // NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that. // auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape()) if (func != nullptr && a->get_shape() == b->get_shape())
{ {
vector<char> out_vec(shape_size(out_shape));
vector<void*> inputs; vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr())); inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr())); inputs.push_back(const_cast<void*>(b->get_data_ptr()));
vector<void*> outputs; vector<void*> outputs;
outputs.push_back(out_vec.data()); outputs.push_back(data_ptr);
func(inputs, outputs); func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_output_element_type(0), out_shape, data_ptr);
} }
else else
{ {
if (auto equal_v0_node = as_type_ptr<op::v0::Equal>(binary)) if (auto equal_v0_node = as_type_ptr<op::v0::Equal>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(), runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
equal_v0_node->get_autob()); equal_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto equal_v1_node = as_type_ptr<op::v1::Equal>(binary)) else if (auto equal_v1_node = as_type_ptr<op::v1::Equal>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(), runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
equal_v1_node->get_autob()); equal_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto greater_v0_node = as_type_ptr<op::v0::Greater>(binary)) else if (auto greater_v0_node = as_type_ptr<op::v0::Greater>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(), runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
greater_v0_node->get_autob()); greater_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto greater_v1_node = as_type_ptr<op::v1::Greater>(binary)) else if (auto greater_v1_node = as_type_ptr<op::v1::Greater>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(), runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
greater_v1_node->get_autob()); greater_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto greater_eq_v0_node = as_type_ptr<op::v0::GreaterEq>(binary)) else if (auto greater_eq_v0_node = as_type_ptr<op::v0::GreaterEq>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(), runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
greater_eq_v0_node->get_autob()); greater_eq_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto greater_eq_v1_node = as_type_ptr<op::v1::GreaterEqual>(binary)) else if (auto greater_eq_v1_node = as_type_ptr<op::v1::GreaterEqual>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(), runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
greater_eq_v1_node->get_autob()); greater_eq_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto less_v0_node = as_type_ptr<op::v0::Less>(binary)) else if (auto less_v0_node = as_type_ptr<op::v0::Less>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(), runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
less_v0_node->get_autob()); less_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto less_v1_node = as_type_ptr<op::v1::Less>(binary)) else if (auto less_v1_node = as_type_ptr<op::v1::Less>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(), runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
less_v1_node->get_autob()); less_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto less_eq_v0_node = as_type_ptr<op::v0::LessEq>(binary)) else if (auto less_eq_v0_node = as_type_ptr<op::v0::LessEq>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(), runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
less_eq_v0_node->get_autob()); less_eq_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto less_eq_v1_node = as_type_ptr<op::v1::LessEqual>(binary)) else if (auto less_eq_v1_node = as_type_ptr<op::v1::LessEqual>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(), runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
less_eq_v1_node->get_autob()); less_eq_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto not_equal_v0_node = as_type_ptr<op::v0::NotEqual>(binary)) else if (auto not_equal_v0_node = as_type_ptr<op::v0::NotEqual>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(), runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
not_equal_v0_node->get_autob()); not_equal_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto not_equal_v1_node = as_type_ptr<op::v1::NotEqual>(binary)) else if (auto not_equal_v1_node = as_type_ptr<op::v1::NotEqual>(binary))
{ {
vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(), runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
not_equal_v1_node->get_autob()); not_equal_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else else
{ {
...@@ -319,21 +321,22 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant ...@@ -319,21 +321,22 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
shared_ptr<Node> binary, shared_ptr<Node> binary,
NodeExecutorTy func) NodeExecutorTy func)
{ {
auto out_shape = binary->get_shape(); const Shape& out_shape = binary->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(Tout));
Tout* data_ptr = buffer.get_ptr<Tout>();
// NOTE: We will skip the executor if the shapes do not match, because that means // NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that. // auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape()) if (func != nullptr && a->get_shape() == b->get_shape())
{ {
vector<Tout> out_vec(shape_size(out_shape));
vector<void*> inputs; vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr())); inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr())); inputs.push_back(const_cast<void*>(b->get_data_ptr()));
vector<void*> outputs; vector<void*> outputs;
outputs.push_back(out_vec.data()); outputs.push_back(data_ptr);
func(inputs, outputs); func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_output_element_type(0), out_shape, data_ptr);
} }
else else
{ {
...@@ -341,178 +344,178 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant ...@@ -341,178 +344,178 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::add<Tin>(a->get_data_ptr<Tin>(), runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
add_v0_node->get_autob()); add_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto add_v1_node = as_type_ptr<op::v1::Add>(binary)) else if (auto add_v1_node = as_type_ptr<op::v1::Add>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::add<Tin>(a->get_data_ptr<Tin>(), runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
add_v1_node->get_autob()); add_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto divide_v0_node = as_type_ptr<op::v0::Divide>(binary)) else if (auto divide_v0_node = as_type_ptr<op::v0::Divide>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v0::Divide> divop = as_type_ptr<op::v0::Divide>(binary); shared_ptr<op::v0::Divide> divop = as_type_ptr<op::v0::Divide>(binary);
bool pythondiv = divop->is_pythondiv(); bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(), runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
divide_v0_node->get_autob(), divide_v0_node->get_autob(),
pythondiv); pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto divide_v1_node = as_type_ptr<op::v1::Divide>(binary)) else if (auto divide_v1_node = as_type_ptr<op::v1::Divide>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v1::Divide> divop = as_type_ptr<op::v1::Divide>(binary); shared_ptr<op::v1::Divide> divop = as_type_ptr<op::v1::Divide>(binary);
bool pythondiv = divop->is_pythondiv(); bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(), runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
divide_v1_node->get_autob(), divide_v1_node->get_autob(),
pythondiv); pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto maximum_v0_node = as_type_ptr<op::v0::Maximum>(binary)) else if (auto maximum_v0_node = as_type_ptr<op::v0::Maximum>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(), runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
maximum_v0_node->get_autob()); maximum_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto maximum_v1_node = as_type_ptr<op::v1::Maximum>(binary)) else if (auto maximum_v1_node = as_type_ptr<op::v1::Maximum>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(), runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
maximum_v1_node->get_autob()); maximum_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto minimum_v0_node = as_type_ptr<op::v0::Minimum>(binary)) else if (auto minimum_v0_node = as_type_ptr<op::v0::Minimum>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(), runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
minimum_v0_node->get_autob()); minimum_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto minimum_v1_node = as_type_ptr<op::v1::Minimum>(binary)) else if (auto minimum_v1_node = as_type_ptr<op::v1::Minimum>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(), runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
minimum_v1_node->get_autob()); minimum_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto multiply_v0_node = as_type_ptr<op::v0::Multiply>(binary)) else if (auto multiply_v0_node = as_type_ptr<op::v0::Multiply>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(), runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
multiply_v0_node->get_autob()); multiply_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto multiply_v1_node = as_type_ptr<op::v1::Multiply>(binary)) else if (auto multiply_v1_node = as_type_ptr<op::v1::Multiply>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(), runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
multiply_v1_node->get_autob()); multiply_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto power_v0_node = as_type_ptr<op::v0::Power>(binary)) else if (auto power_v0_node = as_type_ptr<op::v0::Power>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v0::Power> powop = as_type_ptr<op::v0::Power>(binary); shared_ptr<op::v0::Power> powop = as_type_ptr<op::v0::Power>(binary);
runtime::reference::power<Tin>(a->get_data_ptr<Tin>(), runtime::reference::power<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
power_v0_node->get_autob()); power_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto power_v1_node = as_type_ptr<op::v1::Power>(binary)) else if (auto power_v1_node = as_type_ptr<op::v1::Power>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v1::Power> powop = as_type_ptr<op::v1::Power>(binary); shared_ptr<op::v1::Power> powop = as_type_ptr<op::v1::Power>(binary);
runtime::reference::power<Tin>(a->get_data_ptr<Tin>(), runtime::reference::power<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
power_v1_node->get_autob()); power_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else if (auto subtract_node = as_type_ptr<op::Subtract>(binary)) else if (auto subtract_node = as_type_ptr<op::Subtract>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::subtract<Tin>(a->get_data_ptr<Tin>(), runtime::reference::subtract<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), data_ptr,
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
subtract_node->get_autob()); subtract_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(
binary->get_output_element_type(0), out_shape, data_ptr);
} }
else else
{ {
......
...@@ -27,15 +27,16 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta ...@@ -27,15 +27,16 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
shared_ptr<Node> broadcast, shared_ptr<Node> broadcast,
NodeExecutorTy func) NodeExecutorTy func)
{ {
auto out_shape = broadcast->get_shape(); const Shape& out_shape = broadcast->get_shape();
vector<T> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
if (func != nullptr) if (func)
{ {
vector<void*> inputs; vector<void*> inputs;
inputs.push_back(const_cast<void*>(constant->get_data_ptr())); inputs.push_back(const_cast<void*>(constant->get_data_ptr()));
vector<void*> outputs; vector<void*> outputs;
outputs.push_back(out_vec.data()); outputs.push_back(data_ptr);
func(inputs, outputs); func(inputs, outputs);
} }
...@@ -45,7 +46,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta ...@@ -45,7 +46,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
if (static_bcast_axes.first) if (static_bcast_axes.first)
{ {
runtime::reference::broadcast<T>(constant->get_data_ptr<T>(), runtime::reference::broadcast<T>(constant->get_data_ptr<T>(),
out_vec.data(), data_ptr,
constant->get_shape(), constant->get_shape(),
out_shape, out_shape,
static_bcast_axes.second); static_bcast_axes.second);
...@@ -58,7 +59,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta ...@@ -58,7 +59,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
else if (auto broadcast_v0 = as_type_ptr<op::v0::Broadcast>(broadcast)) else if (auto broadcast_v0 = as_type_ptr<op::v0::Broadcast>(broadcast))
{ {
runtime::reference::broadcast<T>(constant->get_data_ptr<T>(), runtime::reference::broadcast<T>(constant->get_data_ptr<T>(),
out_vec.data(), data_ptr,
constant->get_shape(), constant->get_shape(),
out_shape, out_shape,
broadcast_v0->get_broadcast_axes()); broadcast_v0->get_broadcast_axes());
...@@ -68,7 +69,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta ...@@ -68,7 +69,7 @@ shared_ptr<op::Constant> fold_constant_broadcast(shared_ptr<op::Constant> consta
throw ngraph_error("Unsupported op in broadcast constant folding."); throw ngraph_error("Unsupported op in broadcast constant folding.");
} }
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
} }
void pass::ConstantFolding::construct_constant_broadcast() void pass::ConstantFolding::construct_constant_broadcast()
......
...@@ -36,16 +36,14 @@ static shared_ptr<op::Constant> fold_constant_concat_helper(const shared_ptr<op: ...@@ -36,16 +36,14 @@ static shared_ptr<op::Constant> fold_constant_concat_helper(const shared_ptr<op:
arg_shapes.push_back(input.get_shape()); arg_shapes.push_back(input.get_shape());
} }
std::vector<T> result_vec(shape_size(concat->get_shape())); runtime::AlignedBuffer buffer(shape_size(concat->get_shape()) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
runtime::reference::concat<T>(arg_bufs, runtime::reference::concat<T>(
result_vec.data(), arg_bufs, data_ptr, arg_shapes, concat->get_shape(), concat->get_concatenation_axis());
arg_shapes,
concat->get_shape(),
concat->get_concatenation_axis());
return make_shared<op::Constant>( return make_shared<op::Constant>(
concat->get_output_element_type(0), concat->get_output_shape(0), result_vec); concat->get_output_element_type(0), concat->get_output_shape(0), data_ptr);
} }
void pass::ConstantFolding::construct_constant_concat() void pass::ConstantFolding::construct_constant_concat()
......
...@@ -28,13 +28,14 @@ template <typename TI, typename TO> ...@@ -28,13 +28,14 @@ template <typename TI, typename TO>
shared_ptr<op::Constant> fold_constant_convert_helper1(shared_ptr<op::Constant> constant, shared_ptr<op::Constant> fold_constant_convert_helper1(shared_ptr<op::Constant> constant,
const element::Type& output_element_type) const element::Type& output_element_type)
{ {
auto out_shape = constant->get_shape(); const Shape& out_shape = constant->get_shape();
vector<TO> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(TO));
TO* data_ptr = buffer.get_ptr<TO>();
runtime::reference::convert<TI, TO>( runtime::reference::convert<TI, TO>(
constant->get_vector<TI>().data(), out_vec.data(), shape_size(out_shape)); constant->get_vector<TI>().data(), data_ptr, shape_size(out_shape));
return make_shared<op::Constant>(output_element_type, out_shape, out_vec); return make_shared<op::Constant>(output_element_type, out_shape, data_ptr);
} }
// Helper for mapping element::Types to runtime::reference::convert, which is templated in C++ // Helper for mapping element::Types to runtime::reference::convert, which is templated in C++
......
...@@ -27,18 +27,19 @@ shared_ptr<op::Constant> fold_constant_dequantize(shared_ptr<op::Constant> const ...@@ -27,18 +27,19 @@ shared_ptr<op::Constant> fold_constant_dequantize(shared_ptr<op::Constant> const
shared_ptr<op::Constant> scale, shared_ptr<op::Constant> scale,
shared_ptr<op::Constant> offset) shared_ptr<op::Constant> offset)
{ {
auto out_shape = constant->get_shape(); const Shape& out_shape = constant->get_shape();
vector<REAL> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(REAL));
REAL* data_ptr = buffer.get_ptr<REAL>();
runtime::reference::dequantize<QUANT, REAL>(constant->get_vector<QUANT>().data(), runtime::reference::dequantize<QUANT, REAL>(constant->get_vector<QUANT>().data(),
scale->get_vector<REAL>().data(), scale->get_vector<REAL>().data(),
offset->get_vector<QUANT>().data(), offset->get_vector<QUANT>().data(),
out_vec.data(), data_ptr,
constant->get_shape(), constant->get_shape(),
scale->get_shape(), scale->get_shape(),
dequant->get_axes()); dequant->get_axes());
return make_shared<op::Constant>(dequant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(dequant->get_element_type(), out_shape, data_ptr);
} }
void pass::ConstantFolding::construct_constant_dequantize() void pass::ConstantFolding::construct_constant_dequantize()
......
...@@ -27,16 +27,14 @@ shared_ptr<op::Constant> fold_constant_dyn_broadcast(shared_ptr<op::Constant> ar ...@@ -27,16 +27,14 @@ shared_ptr<op::Constant> fold_constant_dyn_broadcast(shared_ptr<op::Constant> ar
shared_ptr<op::Constant> shape, shared_ptr<op::Constant> shape,
shared_ptr<op::Constant> axes) shared_ptr<op::Constant> axes)
{ {
auto out_shape = shape->get_shape_val(); const Shape& out_shape = shape->get_shape_val();
vector<T> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
runtime::reference::broadcast<T>(arg->get_data_ptr<T>(), runtime::reference::broadcast<T>(
out_vec.data(), arg->get_data_ptr<T>(), data_ptr, arg->get_shape(), out_shape, axes->get_axis_set_val());
arg->get_shape(),
out_shape,
axes->get_axis_set_val());
return make_shared<op::Constant>(arg->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(arg->get_element_type(), out_shape, data_ptr);
} }
void pass::ConstantFolding::construct_constant_dyn_broadcast() void pass::ConstantFolding::construct_constant_dyn_broadcast()
......
...@@ -28,20 +28,20 @@ template <class T> ...@@ -28,20 +28,20 @@ template <class T>
shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> constant_data, shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> constant_data,
shared_ptr<op::v1::Reshape> dyn_reshape) shared_ptr<op::v1::Reshape> dyn_reshape)
{ {
auto out_shape = dyn_reshape->get_shape(); const Shape& out_shape = dyn_reshape->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
AxisVector input_order(constant_data->get_shape().size()); AxisVector input_order(constant_data->get_shape().size());
std::iota(input_order.begin(), input_order.end(), 0); std::iota(input_order.begin(), input_order.end(), 0);
vector<T> out_vec(shape_size(out_shape));
runtime::reference::reshape<T>(constant_data->get_data_ptr<T>(), runtime::reference::reshape<T>(constant_data->get_data_ptr<T>(),
out_vec.data(), data_ptr,
constant_data->get_shape(), constant_data->get_shape(),
input_order, input_order,
out_shape); out_shape);
return make_shared<op::Constant>(dyn_reshape->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(dyn_reshape->get_element_type(), out_shape, data_ptr);
} }
void pass::ConstantFolding::construct_constant_dyn_reshape() void pass::ConstantFolding::construct_constant_dyn_reshape()
......
...@@ -42,31 +42,34 @@ shared_ptr<op::Constant> fold_constant_dyn_slice(shared_ptr<op::Constant> data, ...@@ -42,31 +42,34 @@ shared_ptr<op::Constant> fold_constant_dyn_slice(shared_ptr<op::Constant> data,
slice->get_shrink_axis(), slice->get_shrink_axis(),
slice->get_ellipsis_mask()); slice->get_ellipsis_mask());
vector<T> slice_out_vec(shape_size(plan.reshape_in_shape)); runtime::AlignedBuffer slice_out_buffer(shape_size(plan.reshape_in_shape) * sizeof(T));
T* slice_out_data = slice_out_buffer.get_ptr<T>();
runtime::reference::slice<T>(data->get_data_ptr<T>(), runtime::reference::slice<T>(data->get_data_ptr<T>(),
slice_out_vec.data(), slice_out_data,
data->get_shape(), data->get_shape(),
Coordinate(plan.begins.begin(), plan.begins.end()), Coordinate(plan.begins.begin(), plan.begins.end()),
Coordinate(plan.ends.begin(), plan.ends.end()), Coordinate(plan.ends.begin(), plan.ends.end()),
Strides(plan.strides.begin(), plan.strides.end()), Strides(plan.strides.begin(), plan.strides.end()),
plan.reshape_in_shape); plan.reshape_in_shape);
vector<T> reshape_out_vec(shape_size(plan.reshape_out_shape)); runtime::AlignedBuffer reshape_out_buffer(shape_size(plan.reshape_out_shape) * sizeof(T));
runtime::reference::reshape<T>(slice_out_vec.data(), T* reshape_out_data = reshape_out_buffer.get_ptr<T>();
reshape_out_vec.data(), runtime::reference::reshape<T>(slice_out_data,
reshape_out_data,
plan.reshape_in_shape, plan.reshape_in_shape,
get_default_order(plan.reshape_in_shape.size()), get_default_order(plan.reshape_in_shape.size()),
plan.reshape_out_shape); plan.reshape_out_shape);
vector<T> reverse_out_vec(shape_size(plan.reshape_out_shape)); runtime::AlignedBuffer reverse_out_buffer(shape_size(plan.reshape_out_shape) * sizeof(T));
runtime::reference::reverse<T>(reshape_out_vec.data(), T* reverse_out_data = reverse_out_buffer.get_ptr<T>();
reverse_out_vec.data(), runtime::reference::reverse<T>(reshape_out_data,
reverse_out_data,
plan.reshape_out_shape, plan.reshape_out_shape,
plan.reshape_out_shape, plan.reshape_out_shape,
plan.reverse_axes); plan.reverse_axes);
return make_shared<op::Constant>( return make_shared<op::Constant>(
data->get_element_type(), plan.reshape_out_shape, reverse_out_vec); data->get_element_type(), plan.reshape_out_shape, reverse_out_data);
} }
void pass::ConstantFolding::construct_constant_dyn_slice() void pass::ConstantFolding::construct_constant_dyn_slice()
......
...@@ -28,13 +28,14 @@ static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op: ...@@ -28,13 +28,14 @@ static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op:
const shared_ptr<op::Constant>& indices, const shared_ptr<op::Constant>& indices,
const shared_ptr<Node>& gather) const shared_ptr<Node>& gather)
{ {
std::vector<T> result_vec(shape_size(gather->get_shape())); runtime::AlignedBuffer buffer(shape_size(gather->get_shape()) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
if (auto gather_v1 = as_type_ptr<op::v1::Gather>(gather)) if (auto gather_v1 = as_type_ptr<op::v1::Gather>(gather))
{ {
runtime::reference::gather<T, U>(data->get_data_ptr<T>(), runtime::reference::gather<T, U>(data->get_data_ptr<T>(),
indices->get_data_ptr<U>(), indices->get_data_ptr<U>(),
result_vec.data(), data_ptr,
data->get_shape(), data->get_shape(),
indices->get_shape(), indices->get_shape(),
gather_v1->get_shape(), gather_v1->get_shape(),
...@@ -44,7 +45,7 @@ static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op: ...@@ -44,7 +45,7 @@ static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op:
{ {
runtime::reference::gather<T, U>(data->get_data_ptr<T>(), runtime::reference::gather<T, U>(data->get_data_ptr<T>(),
indices->get_data_ptr<U>(), indices->get_data_ptr<U>(),
result_vec.data(), data_ptr,
data->get_shape(), data->get_shape(),
indices->get_shape(), indices->get_shape(),
gather_v0->get_shape(), gather_v0->get_shape(),
...@@ -56,7 +57,7 @@ static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op: ...@@ -56,7 +57,7 @@ static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op:
} }
return make_shared<op::Constant>( return make_shared<op::Constant>(
gather->get_output_element_type(0), gather->get_output_shape(0), result_vec); gather->get_output_element_type(0), gather->get_output_shape(0), data_ptr);
} }
template <typename T> template <typename T>
......
...@@ -43,12 +43,13 @@ static Shape get_shape_no_keep_dims(const AxisSet& reduction_axes, const Shape& ...@@ -43,12 +43,13 @@ static Shape get_shape_no_keep_dims(const AxisSet& reduction_axes, const Shape&
static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::Constant> constant, static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::Constant> constant,
shared_ptr<Node> reduction_node) shared_ptr<Node> reduction_node)
{ {
vector<char> out_vec(shape_size(reduction_node->get_shape())); runtime::AlignedBuffer buffer(shape_size(reduction_node->get_shape()) * sizeof(char));
char* data_ptr = buffer.get_ptr<char>();
if (auto all = as_type_ptr<::ngraph::op::All>(reduction_node)) if (auto all = as_type_ptr<::ngraph::op::All>(reduction_node))
{ {
runtime::reference::all(constant->get_vector<char>().data(), runtime::reference::all(constant->get_vector<char>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
reduction_node->get_shape(), reduction_node->get_shape(),
all->get_reduction_axes()); all->get_reduction_axes());
...@@ -56,7 +57,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C ...@@ -56,7 +57,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
else if (auto any = as_type_ptr<::ngraph::op::Any>(reduction_node)) else if (auto any = as_type_ptr<::ngraph::op::Any>(reduction_node))
{ {
runtime::reference::any(constant->get_vector<char>().data(), runtime::reference::any(constant->get_vector<char>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
reduction_node->get_shape(), reduction_node->get_shape(),
any->get_reduction_axes()); any->get_reduction_axes());
...@@ -67,7 +68,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C ...@@ -67,7 +68,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
const auto input_shape = reduce_and->get_input_shape(0); const auto input_shape = reduce_and->get_input_shape(0);
runtime::reference::all(constant->get_vector<char>().data(), runtime::reference::all(constant->get_vector<char>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
get_shape_no_keep_dims(reduction_axes, input_shape), get_shape_no_keep_dims(reduction_axes, input_shape),
reduction_axes); reduction_axes);
...@@ -78,7 +79,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C ...@@ -78,7 +79,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
const auto input_shape = reduce_or->get_input_shape(0); const auto input_shape = reduce_or->get_input_shape(0);
runtime::reference::any(constant->get_vector<char>().data(), runtime::reference::any(constant->get_vector<char>().data(),
out_vec.data(), data_ptr,
constant->get_output_shape(0), constant->get_output_shape(0),
get_shape_no_keep_dims(reduction_axes, input_shape), get_shape_no_keep_dims(reduction_axes, input_shape),
reduction_axes); reduction_axes);
...@@ -92,7 +93,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C ...@@ -92,7 +93,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
} }
return make_shared<op::Constant>( return make_shared<op::Constant>(
reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec); reduction_node->get_output_element_type(0), reduction_node->get_shape(), data_ptr);
} }
void pass::ConstantFolding::construct_constant_logical_reduction() void pass::ConstantFolding::construct_constant_logical_reduction()
......
...@@ -26,8 +26,9 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant, ...@@ -26,8 +26,9 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
shared_ptr<op::Pad> pad, shared_ptr<op::Pad> pad,
NodeExecutorTy func) NodeExecutorTy func)
{ {
auto out_shape = pad->get_shape(); const Shape& out_shape = pad->get_shape();
vector<T> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
auto pad_value = std::static_pointer_cast<op::Constant>( auto pad_value = std::static_pointer_cast<op::Constant>(
pad->input(1).get_source_output().get_node_shared_ptr()); pad->input(1).get_source_output().get_node_shared_ptr());
...@@ -38,7 +39,7 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant, ...@@ -38,7 +39,7 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
inputs.push_back(const_cast<void*>(pad_value->get_data_ptr())); inputs.push_back(const_cast<void*>(pad_value->get_data_ptr()));
vector<void*> outputs; vector<void*> outputs;
outputs.push_back(out_vec.data()); outputs.push_back(data_ptr);
func(inputs, outputs); func(inputs, outputs);
} }
...@@ -46,7 +47,7 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant, ...@@ -46,7 +47,7 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
{ {
runtime::reference::pad<T>(constant->get_data_ptr<T>(), runtime::reference::pad<T>(constant->get_data_ptr<T>(),
pad_value->get_data_ptr<T>(), pad_value->get_data_ptr<T>(),
out_vec.data(), data_ptr,
constant->get_shape(), constant->get_shape(),
out_shape, out_shape,
pad->get_padding_below(), pad->get_padding_below(),
...@@ -54,7 +55,7 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant, ...@@ -54,7 +55,7 @@ shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
pad->get_pad_mode()); pad->get_pad_mode());
} }
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
} }
void pass::ConstantFolding::construct_constant_pad() void pass::ConstantFolding::construct_constant_pad()
......
...@@ -27,19 +27,20 @@ shared_ptr<op::Constant> fold_constant_quantize(shared_ptr<op::Constant> constan ...@@ -27,19 +27,20 @@ shared_ptr<op::Constant> fold_constant_quantize(shared_ptr<op::Constant> constan
shared_ptr<op::Constant> scale, shared_ptr<op::Constant> scale,
shared_ptr<op::Constant> offset) shared_ptr<op::Constant> offset)
{ {
auto out_shape = constant->get_shape(); const Shape& out_shape = constant->get_shape();
vector<QUANT> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(QUANT));
QUANT* data_ptr = buffer.get_ptr<QUANT>();
runtime::reference::quantize<REAL, QUANT>(constant->get_vector<REAL>().data(), runtime::reference::quantize<REAL, QUANT>(constant->get_vector<REAL>().data(),
scale->get_vector<REAL>().data(), scale->get_vector<REAL>().data(),
offset->get_vector<QUANT>().data(), offset->get_vector<QUANT>().data(),
out_vec.data(), data_ptr,
constant->get_shape(), constant->get_shape(),
scale->get_shape(), scale->get_shape(),
quant->get_axes(), quant->get_axes(),
quant->get_round_mode()); quant->get_round_mode());
return make_shared<op::Constant>(quant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(quant->get_element_type(), out_shape, data_ptr);
} }
void pass::ConstantFolding::construct_constant_quantize() void pass::ConstantFolding::construct_constant_quantize()
......
...@@ -26,13 +26,12 @@ shared_ptr<op::Constant> fold_constant_range(shared_ptr<op::Constant> start, ...@@ -26,13 +26,12 @@ shared_ptr<op::Constant> fold_constant_range(shared_ptr<op::Constant> start,
shared_ptr<op::Constant> step, shared_ptr<op::Constant> step,
shared_ptr<op::Range> range) shared_ptr<op::Range> range)
{ {
vector<T> out_vec(shape_size(range->get_shape())); runtime::AlignedBuffer buffer(shape_size(range->get_shape()) * sizeof(T));
runtime::reference::range<T>(start->get_vector<T>().data(), T* data_ptr = buffer.get_ptr<T>();
step->get_vector<T>().data(), runtime::reference::range<T>(
range->get_shape(), start->get_vector<T>().data(), step->get_vector<T>().data(), range->get_shape(), data_ptr);
out_vec.data());
return make_shared<op::Constant>(range->get_element_type(), range->get_shape(), out_vec); return make_shared<op::Constant>(range->get_element_type(), range->get_shape(), data_ptr);
} }
void pass::ConstantFolding::construct_constant_range() void pass::ConstantFolding::construct_constant_range()
......
...@@ -26,28 +26,29 @@ shared_ptr<op::Constant> fold_constant_reshape(shared_ptr<op::Constant> constant ...@@ -26,28 +26,29 @@ shared_ptr<op::Constant> fold_constant_reshape(shared_ptr<op::Constant> constant
shared_ptr<op::Reshape> reshape, shared_ptr<op::Reshape> reshape,
NodeExecutorTy func) NodeExecutorTy func)
{ {
auto out_shape = reshape->get_shape(); const Shape& out_shape = reshape->get_shape();
vector<T> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
if (func != nullptr) if (func != nullptr)
{ {
vector<void*> inputs; vector<void*> inputs;
inputs.push_back(const_cast<void*>(constant->get_data_ptr())); inputs.push_back(const_cast<void*>(constant->get_data_ptr()));
vector<void*> outputs; vector<void*> outputs;
outputs.push_back(out_vec.data()); outputs.push_back(data_ptr);
func(inputs, outputs); func(inputs, outputs);
} }
else else
{ {
runtime::reference::reshape<T>(constant->get_data_ptr<T>(), runtime::reference::reshape<T>(constant->get_data_ptr<T>(),
out_vec.data(), data_ptr,
constant->get_shape(), constant->get_shape(),
reshape->get_input_order(), reshape->get_input_order(),
out_shape); out_shape);
} }
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
} }
void pass::ConstantFolding::construct_constant_reshape() void pass::ConstantFolding::construct_constant_reshape()
......
...@@ -25,13 +25,14 @@ template <typename T> ...@@ -25,13 +25,14 @@ template <typename T>
static shared_ptr<op::Constant> fold_constant_reverse_helper(shared_ptr<op::Constant> constant, static shared_ptr<op::Constant> fold_constant_reverse_helper(shared_ptr<op::Constant> constant,
const AxisSet& reversed_axes) const AxisSet& reversed_axes)
{ {
auto out_shape = constant->get_shape(); const Shape& out_shape = constant->get_shape();
vector<T> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
runtime::reference::reverse<T>( runtime::reference::reverse<T>(
constant->get_vector<T>().data(), out_vec.data(), out_shape, out_shape, reversed_axes); constant->get_vector<T>().data(), data_ptr, out_shape, out_shape, reversed_axes);
return make_shared<op::Constant>(constant->get_output_element_type(0), out_shape, out_vec); return make_shared<op::Constant>(constant->get_output_element_type(0), out_shape, data_ptr);
} }
static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> constant, static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> constant,
......
...@@ -27,15 +27,16 @@ shared_ptr<op::Constant> fold_constant_select(const shared_ptr<op::Constant>& se ...@@ -27,15 +27,16 @@ shared_ptr<op::Constant> fold_constant_select(const shared_ptr<op::Constant>& se
const shared_ptr<op::Constant>& f, const shared_ptr<op::Constant>& f,
const shared_ptr<Node>& select) const shared_ptr<Node>& select)
{ {
auto out_shape = select->get_shape(); const Shape& out_shape = select->get_shape();
vector<T> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
if (auto select_v0 = as_type_ptr<op::v0::Select>(select)) if (auto select_v0 = as_type_ptr<op::v0::Select>(select))
{ {
runtime::reference::select<T>(selection->get_data_ptr<char>(), runtime::reference::select<T>(selection->get_data_ptr<char>(),
t->get_data_ptr<T>(), t->get_data_ptr<T>(),
f->get_data_ptr<T>(), f->get_data_ptr<T>(),
out_vec.data(), data_ptr,
shape_size(out_shape)); shape_size(out_shape));
} }
else if (auto select_v1 = as_type_ptr<op::v1::Select>(select)) else if (auto select_v1 = as_type_ptr<op::v1::Select>(select))
...@@ -43,14 +44,14 @@ shared_ptr<op::Constant> fold_constant_select(const shared_ptr<op::Constant>& se ...@@ -43,14 +44,14 @@ shared_ptr<op::Constant> fold_constant_select(const shared_ptr<op::Constant>& se
runtime::reference::select<T>(selection->get_data_ptr<char>(), runtime::reference::select<T>(selection->get_data_ptr<char>(),
t->get_data_ptr<T>(), t->get_data_ptr<T>(),
f->get_data_ptr<T>(), f->get_data_ptr<T>(),
out_vec.data(), data_ptr,
selection->get_shape(), selection->get_shape(),
t->get_shape(), t->get_shape(),
f->get_shape(), f->get_shape(),
select_v1->get_auto_broadcast()); select_v1->get_auto_broadcast());
} }
return make_shared<op::Constant>(select->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(select->get_element_type(), out_shape, data_ptr);
} }
void pass::ConstantFolding::construct_constant_select() void pass::ConstantFolding::construct_constant_select()
......
...@@ -25,18 +25,19 @@ template <class T> ...@@ -25,18 +25,19 @@ template <class T>
shared_ptr<op::Constant> fold_constant_slice(shared_ptr<op::Constant> constant, shared_ptr<op::Constant> fold_constant_slice(shared_ptr<op::Constant> constant,
shared_ptr<op::Slice> slice) shared_ptr<op::Slice> slice)
{ {
auto out_shape = slice->get_shape(); const Shape& out_shape = slice->get_shape();
vector<T> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
T* data_ptr = buffer.get_ptr<T>();
runtime::reference::slice<T>(constant->get_data_ptr<T>(), runtime::reference::slice<T>(constant->get_data_ptr<T>(),
out_vec.data(), data_ptr,
constant->get_shape(), constant->get_shape(),
slice->get_lower_bounds(), slice->get_lower_bounds(),
slice->get_upper_bounds(), slice->get_upper_bounds(),
slice->get_strides(), slice->get_strides(),
out_shape); out_shape);
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
} }
void pass::ConstantFolding::construct_constant_slice() void pass::ConstantFolding::construct_constant_slice()
......
...@@ -24,10 +24,9 @@ template <class T> ...@@ -24,10 +24,9 @@ template <class T>
shared_ptr<op::Constant> fold_constant_squeeze(shared_ptr<op::Constant> constant, shared_ptr<op::Constant> fold_constant_squeeze(shared_ptr<op::Constant> constant,
shared_ptr<op::Squeeze> squeeze) shared_ptr<op::Squeeze> squeeze)
{ {
auto out_shape = squeeze->get_shape(); const Shape& out_shape = squeeze->get_shape();
vector<T> out_vec(shape_size(out_shape)); return make_shared<op::Constant>(
out_vec = constant->get_vector<T>(); constant->get_element_type(), out_shape, constant->get_data_ptr());
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
} }
void pass::ConstantFolding::construct_constant_squeeze() void pass::ConstantFolding::construct_constant_squeeze()
......
...@@ -54,31 +54,31 @@ shared_ptr<op::Constant> fold_constant_strided_slice(shared_ptr<op::Constant> da ...@@ -54,31 +54,31 @@ shared_ptr<op::Constant> fold_constant_strided_slice(shared_ptr<op::Constant> da
convert_mask_to_axis_set(slice->get_shrink_axis_mask()), convert_mask_to_axis_set(slice->get_shrink_axis_mask()),
convert_mask_to_axis_set(slice->get_ellipsis_mask())); convert_mask_to_axis_set(slice->get_ellipsis_mask()));
vector<T> slice_out_vec(shape_size(plan.reshape_in_shape)); runtime::AlignedBuffer slice_out_buffer(shape_size(plan.reshape_in_shape) * sizeof(T));
runtime::reference::slice<T>(data->get_data_ptr<T>(), runtime::reference::slice<T>(data->get_data_ptr<T>(),
slice_out_vec.data(), slice_out_buffer.get_ptr<T>(),
data->get_shape(), data->get_shape(),
Coordinate(plan.begins.begin(), plan.begins.end()), Coordinate(plan.begins.begin(), plan.begins.end()),
Coordinate(plan.ends.begin(), plan.ends.end()), Coordinate(plan.ends.begin(), plan.ends.end()),
Strides(plan.strides.begin(), plan.strides.end()), Strides(plan.strides.begin(), plan.strides.end()),
plan.reshape_in_shape); plan.reshape_in_shape);
vector<T> reshape_out_vec(shape_size(plan.reshape_out_shape)); runtime::AlignedBuffer reshape_out_buffer(shape_size(plan.reshape_out_shape) * sizeof(T));
runtime::reference::reshape<T>(slice_out_vec.data(), runtime::reference::reshape<T>(slice_out_buffer.get_ptr<T>(),
reshape_out_vec.data(), reshape_out_buffer.get_ptr<T>(),
plan.reshape_in_shape, plan.reshape_in_shape,
get_default_order(plan.reshape_in_shape.size()), get_default_order(plan.reshape_in_shape.size()),
plan.reshape_out_shape); plan.reshape_out_shape);
vector<T> reverse_out_vec(shape_size(plan.reshape_out_shape)); runtime::AlignedBuffer reverse_out_buffer(shape_size(plan.reshape_out_shape) * sizeof(T));
runtime::reference::reverse<T>(reshape_out_vec.data(), runtime::reference::reverse<T>(reshape_out_buffer.get_ptr<T>(),
reverse_out_vec.data(), reverse_out_buffer.get_ptr<T>(),
plan.reshape_out_shape, plan.reshape_out_shape,
plan.reshape_out_shape, plan.reshape_out_shape,
plan.reverse_axes); plan.reverse_axes);
return make_shared<op::Constant>( return make_shared<op::Constant>(
data->get_element_type(), plan.reshape_out_shape, reverse_out_vec); data->get_element_type(), plan.reshape_out_shape, reverse_out_buffer.get_ptr<T>());
} }
void pass::ConstantFolding::construct_constant_strided_slice() void pass::ConstantFolding::construct_constant_strided_slice()
......
...@@ -26,18 +26,18 @@ shared_ptr<op::Constant> fold_constant_transpose(shared_ptr<op::Constant> consta ...@@ -26,18 +26,18 @@ shared_ptr<op::Constant> fold_constant_transpose(shared_ptr<op::Constant> consta
shared_ptr<op::Constant> constant_perm, shared_ptr<op::Constant> constant_perm,
shared_ptr<op::Transpose> transpose) shared_ptr<op::Transpose> transpose)
{ {
auto out_shape = transpose->get_shape(); const Shape& out_shape = transpose->get_shape();
auto input_order = constant_perm->get_axis_vector_val(); auto input_order = constant_perm->get_axis_vector_val();
vector<T> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
runtime::reference::reshape<T>(constant_data->get_data_ptr<T>(), runtime::reference::reshape<T>(constant_data->get_data_ptr<T>(),
out_vec.data(), buffer.get_ptr<T>(),
constant_data->get_shape(), constant_data->get_shape(),
input_order, input_order,
out_shape); out_shape);
return make_shared<op::Constant>(transpose->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(transpose->get_element_type(), out_shape, buffer.get_ptr<T>());
} }
void pass::ConstantFolding::construct_constant_transpose() void pass::ConstantFolding::construct_constant_transpose()
......
...@@ -60,15 +60,15 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant, ...@@ -60,15 +60,15 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
} }
} }
auto out_shape = unary->get_shape(); const Shape& out_shape = unary->get_shape();
vector<T> out_vec(shape_size(out_shape)); runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
if (func != nullptr) if (func != nullptr)
{ {
vector<void*> inputs; vector<void*> inputs;
inputs.push_back(const_cast<void*>(constant->get_data_ptr())); inputs.push_back(const_cast<void*>(constant->get_data_ptr()));
vector<void*> outputs; vector<void*> outputs;
outputs.push_back(out_vec.data()); outputs.push_back(buffer.get_ptr<T>());
func(inputs, outputs); func(inputs, outputs);
} }
...@@ -77,47 +77,47 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant, ...@@ -77,47 +77,47 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
if (is_type<op::Abs>(unary)) if (is_type<op::Abs>(unary))
{ {
runtime::reference::abs<T>( runtime::reference::abs<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
} }
else if (is_type<op::Ceiling>(unary)) else if (is_type<op::Ceiling>(unary))
{ {
runtime::reference::ceiling<T>( runtime::reference::ceiling<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
} }
else if (is_type<op::Floor>(unary)) else if (is_type<op::Floor>(unary))
{ {
runtime::reference::floor<T>( runtime::reference::floor<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
} }
else if (is_type<op::v1::LogicalNot>(unary)) else if (is_type<op::v1::LogicalNot>(unary))
{ {
runtime::reference::logical_not<T>( runtime::reference::logical_not<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
} }
else if (is_type<op::Negative>(unary)) else if (is_type<op::Negative>(unary))
{ {
runtime::reference::negate<T>( runtime::reference::negate<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
} }
else if (is_type<op::v0::Not>(unary)) else if (is_type<op::v0::Not>(unary))
{ {
runtime::reference::logical_not<T>( runtime::reference::logical_not<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
} }
else if (is_type<op::Relu>(unary)) else if (is_type<op::Relu>(unary))
{ {
runtime::reference::relu<T>( runtime::reference::relu<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
} }
else if (is_type<op::Sign>(unary)) else if (is_type<op::Sign>(unary))
{ {
runtime::reference::sign<T>( runtime::reference::sign<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
} }
else if (is_type<op::Sqrt>(unary)) else if (is_type<op::Sqrt>(unary))
{ {
runtime::reference::sqrt<T>( runtime::reference::sqrt<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), buffer.get_ptr<T>(), shape_size(out_shape));
} }
else else
{ {
...@@ -125,7 +125,7 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant, ...@@ -125,7 +125,7 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
} }
} }
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, buffer.get_ptr<T>());
} }
void pass::ConstantFolding::construct_constant_unary() void pass::ConstantFolding::construct_constant_unary()
......
...@@ -24,10 +24,9 @@ template <class T> ...@@ -24,10 +24,9 @@ template <class T>
shared_ptr<op::Constant> fold_constant_unsqueeze(shared_ptr<op::Constant> constant, shared_ptr<op::Constant> fold_constant_unsqueeze(shared_ptr<op::Constant> constant,
shared_ptr<op::Unsqueeze> unsqueeze) shared_ptr<op::Unsqueeze> unsqueeze)
{ {
auto out_shape = unsqueeze->get_shape(); const Shape& out_shape = unsqueeze->get_shape();
vector<T> out_vec(shape_size(out_shape)); return make_shared<op::Constant>(
out_vec = constant->get_vector<T>(); constant->get_element_type(), out_shape, constant->get_data_ptr());
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
} }
void pass::ConstantFolding::construct_constant_unsqueeze() void pass::ConstantFolding::construct_constant_unsqueeze()
......
...@@ -32,7 +32,7 @@ bool pass::ConstantToBroadcast::run_on_node(shared_ptr<Node> node) ...@@ -32,7 +32,7 @@ bool pass::ConstantToBroadcast::run_on_node(shared_ptr<Node> node)
size_t size = shape_size(constant->get_shape()); size_t size = shape_size(constant->get_shape());
if (size > minimum_size_of_interest) if (size > minimum_size_of_interest)
{ {
if (constant->are_all_data_elements_bitwise_identical()) if (constant->get_all_data_elements_bitwise_identical())
{ {
auto scalar_constant = make_shared<op::Constant>( auto scalar_constant = make_shared<op::Constant>(
constant->get_element_type(), Shape{}, constant->get_data_ptr()); constant->get_element_type(), Shape{}, constant->get_data_ptr());
......
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
// Allocator objects and the allocation interfaces are owned by the // Allocator objects and the allocation interfaces are owned by the
// creators of AlignedBuffers. They need to ensure that the lifetime of // creators of AlignedBuffers. They need to ensure that the lifetime of
// allocator exceeds the lifetime of this AlignedBuffer. // allocator exceeds the lifetime of this AlignedBuffer.
AlignedBuffer(size_t byte_size, size_t alignment, Allocator* allocator = nullptr); AlignedBuffer(size_t byte_size, size_t alignment = 64, Allocator* allocator = nullptr);
AlignedBuffer(); AlignedBuffer();
~AlignedBuffer(); ~AlignedBuffer();
...@@ -47,7 +47,25 @@ public: ...@@ -47,7 +47,25 @@ public:
size_t size() const { return m_byte_size; } size_t size() const { return m_byte_size; }
void* get_ptr(size_t offset) const { return m_aligned_buffer + offset; } void* get_ptr(size_t offset) const { return m_aligned_buffer + offset; }
void* get_ptr() const { return m_aligned_buffer; } void* get_ptr() { return m_aligned_buffer; }
const void* get_ptr() const { return m_aligned_buffer; }
template <typename T>
T* get_ptr()
{
return reinterpret_cast<T*>(m_aligned_buffer);
}
template <typename T>
const T* get_ptr() const
{
return reinterpret_cast<const T*>(m_aligned_buffer);
}
template <typename T>
explicit operator T*()
{
return get_ptr<T>();
}
private: private:
AlignedBuffer(const AlignedBuffer&) = delete; AlignedBuffer(const AlignedBuffer&) = delete;
AlignedBuffer& operator=(const AlignedBuffer&) = delete; AlignedBuffer& operator=(const AlignedBuffer&) = delete;
......
...@@ -3338,7 +3338,7 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -3338,7 +3338,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Constant: case OP_TYPEID::Constant:
{ {
auto tmp = static_cast<const op::Constant*>(&n); auto tmp = static_cast<const op::Constant*>(&n);
if (tmp->are_all_data_elements_bitwise_identical() && shape_size(tmp->get_shape()) > 0) if (tmp->get_all_data_elements_bitwise_identical() && shape_size(tmp->get_shape()) > 0)
{ {
vector<string> vs; vector<string> vs;
vs.push_back(tmp->convert_value_to_string(0)); vs.push_back(tmp->convert_value_to_string(0));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment