Commit 3ff9e490 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon Committed by Scott Cyphers

DEX: Propagate inputs & outputs (#1406)

* Added MKLDNN concat in DEX

* Fix allocation size in batchnorm kernel

* added missing brackets

* Support in-place input & output propagation
parent 84719348
...@@ -35,13 +35,12 @@ namespace ngraph ...@@ -35,13 +35,12 @@ namespace ngraph
auto avg_pool = static_cast<const ngraph::op::AvgPool*>(node); auto avg_pool = static_cast<const ngraph::op::AvgPool*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto window_shape = avg_pool->get_window_shape(); auto window_shape = avg_pool->get_window_shape();
auto window_movement_strides = avg_pool->get_window_movement_strides(); auto window_movement_strides = avg_pool->get_window_movement_strides();
...@@ -112,13 +111,12 @@ namespace ngraph ...@@ -112,13 +111,12 @@ namespace ngraph
auto apb = static_cast<const ngraph::op::AvgPoolBackprop*>(node); auto apb = static_cast<const ngraph::op::AvgPoolBackprop*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto delta_shape = args[0].get_shape(); auto delta_shape = args[0].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto& delta_tensor = tensor_data[args[0].get_name()]; auto& delta_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto window_shape = apb->get_window_shape(); auto window_shape = apb->get_window_shape();
auto window_movement_strides = apb->get_window_movement_strides(); auto window_movement_strides = apb->get_window_movement_strides();
......
...@@ -41,12 +41,11 @@ namespace ngraph ...@@ -41,12 +41,11 @@ namespace ngraph
bool append_relu) bool append_relu)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& out0_tensor = tensor_data[out[0].get_name()]; auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
const OP* batchnorm = static_cast<const OP*>(node); const OP* batchnorm = static_cast<const OP*>(node);
...@@ -75,8 +74,8 @@ namespace ngraph ...@@ -75,8 +74,8 @@ namespace ngraph
if (batchnorm->get_training_flag() && args.size() == 3) if (batchnorm->get_training_flag() && args.size() == 3)
{ {
auto& out1_tensor = tensor_data[out[1].get_name()]; auto& out1_tensor = external_function->get_tensor_data(out[1].get_name());
auto& out2_tensor = tensor_data[out[2].get_name()]; auto& out2_tensor = external_function->get_tensor_data(out[2].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 2); auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 2);
...@@ -117,8 +116,8 @@ namespace ngraph ...@@ -117,8 +116,8 @@ namespace ngraph
} }
else else
{ {
auto& arg3_tensor = tensor_data[args[3].get_name()]; auto& arg3_tensor = external_function->get_tensor_data(args[3].get_name());
auto& arg4_tensor = tensor_data[args[4].get_name()]; auto& arg4_tensor = external_function->get_tensor_data(args[4].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()}; auto weights_shape = Shape{2, args[0].get_size()};
...@@ -171,7 +170,6 @@ namespace ngraph ...@@ -171,7 +170,6 @@ namespace ngraph
if (batchnorm->get_training_flag() && args.size() == 3) if (batchnorm->get_training_flag() && args.size() == 3)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
std::function<decltype( std::function<decltype(
runtime::cpu::kernel::batch_norm_three_outputs<float>)> runtime::cpu::kernel::batch_norm_three_outputs<float>)>
...@@ -182,13 +180,13 @@ namespace ngraph ...@@ -182,13 +180,13 @@ namespace ngraph
runtime::cpu::kernel::batch_norm_three_outputs); runtime::cpu::kernel::batch_norm_three_outputs);
auto arg2_shape = args[2].get_shape(); auto arg2_shape = args[2].get_shape();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& out0_tensor = tensor_data[out[0].get_name()]; auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
auto& out1_tensor = tensor_data[out[1].get_name()]; auto& out1_tensor = external_function->get_tensor_data(out[1].get_name());
auto& out2_tensor = tensor_data[out[2].get_name()]; auto& out2_tensor = external_function->get_tensor_data(out[2].get_name());
auto eps = batchnorm->get_eps_value(); auto eps = batchnorm->get_eps_value();
auto functor = [&, kernel, arg2_shape, eps](CPURuntimeContext* ctx) { auto functor = [&, kernel, arg2_shape, eps](CPURuntimeContext* ctx) {
...@@ -206,7 +204,6 @@ namespace ngraph ...@@ -206,7 +204,6 @@ namespace ngraph
else else
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
std::function<decltype(runtime::cpu::kernel::batch_norm_one_output<float>)> std::function<decltype(runtime::cpu::kernel::batch_norm_one_output<float>)>
kernel; kernel;
...@@ -216,13 +213,13 @@ namespace ngraph ...@@ -216,13 +213,13 @@ namespace ngraph
runtime::cpu::kernel::batch_norm_one_output); runtime::cpu::kernel::batch_norm_one_output);
auto arg2_shape = args[2].get_shape(); auto arg2_shape = args[2].get_shape();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& arg3_tensor = tensor_data[args[3].get_name()]; auto& arg3_tensor = external_function->get_tensor_data(args[3].get_name());
auto& arg4_tensor = tensor_data[args[4].get_name()]; auto& arg4_tensor = external_function->get_tensor_data(args[4].get_name());
auto& out0_tensor = tensor_data[out[0].get_name()]; auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
auto eps = batchnorm->get_eps_value(); auto eps = batchnorm->get_eps_value();
auto functor = [&, kernel, arg2_shape, eps](CPURuntimeContext* ctx) { auto functor = [&, kernel, arg2_shape, eps](CPURuntimeContext* ctx) {
...@@ -252,18 +249,17 @@ namespace ngraph ...@@ -252,18 +249,17 @@ namespace ngraph
static_cast<const ngraph::op::BatchNormBackprop*>(node); static_cast<const ngraph::op::BatchNormBackprop*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg3_tensor = external_function->get_tensor_data(args[3].get_name());
auto& arg3_tensor = tensor_data[args[3].get_name()]; auto& arg4_tensor = external_function->get_tensor_data(args[4].get_name());
auto& arg4_tensor = tensor_data[args[4].get_name()]; auto& arg5_tensor = external_function->get_tensor_data(args[5].get_name());
auto& arg5_tensor = tensor_data[args[5].get_name()];
auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
auto& out0_tensor = tensor_data[out[0].get_name()]; auto& out1_tensor = external_function->get_tensor_data(out[1].get_name());
auto& out1_tensor = tensor_data[out[1].get_name()]; auto& out2_tensor = external_function->get_tensor_data(out[2].get_name());
auto& out2_tensor = tensor_data[out[2].get_name()];
// Kill clang diagnostics bug // Kill clang diagnostics bug
#pragma clang diagnostic push #pragma clang diagnostic push
......
...@@ -33,10 +33,9 @@ namespace ngraph ...@@ -33,10 +33,9 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::BoundedRelu) void Builder::BUILDER_DECL(ngraph::op::BoundedRelu)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& input_tensor = tensor_data[args[0].get_name()]; auto& input_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
size_t count = out[0].get_size(); size_t count = out[0].get_size();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
......
...@@ -36,7 +36,6 @@ namespace ngraph ...@@ -36,7 +36,6 @@ namespace ngraph
(static_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis(); (static_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis();
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
vector<reference_wrapper<void*>> arg_tensors; vector<reference_wrapper<void*>> arg_tensors;
vector<Shape> arg_shapes; vector<Shape> arg_shapes;
...@@ -44,12 +43,13 @@ namespace ngraph ...@@ -44,12 +43,13 @@ namespace ngraph
{ {
if (shape_size(arg.get_shape())) if (shape_size(arg.get_shape()))
{ {
arg_tensors.emplace_back(tensor_data[arg.get_name()]); arg_tensors.emplace_back(
external_function->get_tensor_data(arg.get_name()));
arg_shapes.emplace_back(arg.get_shape()); arg_shapes.emplace_back(arg.get_shape());
} }
} }
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
......
...@@ -31,10 +31,9 @@ namespace ngraph ...@@ -31,10 +31,9 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::Convert) void Builder::BUILDER_DECL(ngraph::op::Convert)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto element_count = out[0].get_size(); auto element_count = out[0].get_size();
......
...@@ -32,10 +32,9 @@ namespace ngraph ...@@ -32,10 +32,9 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::runtime::cpu::op::ConvertLayout) void Builder::BUILDER_DECL(ngraph::runtime::cpu::op::ConvertLayout)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
......
...@@ -38,15 +38,14 @@ namespace ngraph ...@@ -38,15 +38,14 @@ namespace ngraph
auto convolution = static_cast<const ngraph::op::Convolution*>(node); auto convolution = static_cast<const ngraph::op::Convolution*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape(); auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
...@@ -113,11 +112,10 @@ namespace ngraph ...@@ -113,11 +112,10 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::ConvolutionRelu) void Builder::BUILDER_DECL(ngraph::op::ConvolutionRelu)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
...@@ -145,12 +143,11 @@ namespace ngraph ...@@ -145,12 +143,11 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::ConvolutionBias) void Builder::BUILDER_DECL(ngraph::op::ConvolutionBias)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
...@@ -179,12 +176,11 @@ namespace ngraph ...@@ -179,12 +176,11 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::ConvolutionBiasAdd) void Builder::BUILDER_DECL(ngraph::op::ConvolutionBiasAdd)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
...@@ -215,15 +211,14 @@ namespace ngraph ...@@ -215,15 +211,14 @@ namespace ngraph
auto convolution = static_cast<const ngraph::op::ConvolutionBackpropData*>(node); auto convolution = static_cast<const ngraph::op::ConvolutionBackpropData*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape(); auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
...@@ -296,15 +291,14 @@ namespace ngraph ...@@ -296,15 +291,14 @@ namespace ngraph
auto convolution = static_cast<const ngraph::op::ConvolutionBackpropFilters*>(node); auto convolution = static_cast<const ngraph::op::ConvolutionBackpropFilters*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape(); auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
...@@ -375,12 +369,11 @@ namespace ngraph ...@@ -375,12 +369,11 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::ConvolutionBiasBackpropFiltersBias) void Builder::BUILDER_DECL(ngraph::op::ConvolutionBiasBackpropFiltersBias)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out0_tensor = tensor_data[out[0].get_name()]; auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
auto& out1_tensor = tensor_data[out[1].get_name()]; auto& out1_tensor = external_function->get_tensor_data(out[1].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
...@@ -409,11 +402,10 @@ namespace ngraph ...@@ -409,11 +402,10 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::GroupConvolution) void Builder::BUILDER_DECL(ngraph::op::GroupConvolution)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto convolution = static_cast<const ngraph::op::GroupConvolution*>(node); auto convolution = static_cast<const ngraph::op::GroupConvolution*>(node);
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include <cstring> #include <cstring>
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
...@@ -34,15 +35,14 @@ namespace ngraph ...@@ -34,15 +35,14 @@ namespace ngraph
auto dot = static_cast<const ngraph::op::Dot*>(node); auto dot = static_cast<const ngraph::op::Dot*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape(); auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto reduction_axes_count = dot->get_reduction_axes_count(); auto reduction_axes_count = dot->get_reduction_axes_count();
...@@ -68,8 +68,8 @@ namespace ngraph ...@@ -68,8 +68,8 @@ namespace ngraph
auto first = (arg0_shape.empty() ? args[0] : args[1]); auto first = (arg0_shape.empty() ? args[0] : args[1]);
auto second = (arg0_shape.empty() ? args[1] : args[0]); auto second = (arg0_shape.empty() ? args[1] : args[0]);
auto& first_tensor = tensor_data[first.get_name()]; auto& first_tensor = external_function->get_tensor_data(first.get_name());
auto& second_tensor = tensor_data[second.get_name()]; auto& second_tensor = external_function->get_tensor_data(second.get_name());
std::function<decltype(runtime::cpu::kernel::dot_scalar<float>)> kernel; std::function<decltype(runtime::cpu::kernel::dot_scalar<float>)> kernel;
......
...@@ -36,7 +36,6 @@ namespace ngraph ...@@ -36,7 +36,6 @@ namespace ngraph
auto backend = runtime::Backend::create("CPU"); auto backend = runtime::Backend::create("CPU");
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& callees = external_function->get_callees(); auto& callees = external_function->get_callees();
// Note: We bypass the completely broken ngraph "backend" API here // Note: We bypass the completely broken ngraph "backend" API here
...@@ -48,14 +47,14 @@ namespace ngraph ...@@ -48,14 +47,14 @@ namespace ngraph
{ {
arg_shapes.emplace_back(arg.get_shape()); arg_shapes.emplace_back(arg.get_shape());
arg_types.emplace_back(arg.get_element_type()); arg_types.emplace_back(arg.get_element_type());
arg_tensors.emplace_back(tensor_data[arg.get_name()]); arg_tensors.emplace_back(external_function->get_tensor_data(arg.get_name()));
} }
for (const auto& result : out) for (const auto& result : out)
{ {
out_shapes.emplace_back(result.get_shape()); out_shapes.emplace_back(result.get_shape());
out_types.emplace_back(result.get_element_type()); out_types.emplace_back(result.get_element_type());
out_tensors.emplace_back(tensor_data[result.get_name()]); out_tensors.emplace_back(external_function->get_tensor_data(result.get_name()));
} }
if (!callees.count(function->get_name())) if (!callees.count(function->get_name()))
......
...@@ -33,13 +33,12 @@ namespace ngraph ...@@ -33,13 +33,12 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::LRN) void Builder::BUILDER_DECL(ngraph::op::LRN)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
const ngraph::op::LRN* lrn = static_cast<const ngraph::op::LRN*>(node); const ngraph::op::LRN* lrn = static_cast<const ngraph::op::LRN*>(node);
function<void(CPURuntimeContext*)> functor; function<void(CPURuntimeContext*)> functor;
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
......
...@@ -46,15 +46,14 @@ namespace ngraph ...@@ -46,15 +46,14 @@ namespace ngraph
"kernel"); "kernel");
} }
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& src_layer_tensor = tensor_data[args[0].get_name()]; auto& src_layer_tensor = external_function->get_tensor_data(args[0].get_name());
auto& src_iter_tensor = tensor_data[args[1].get_name()]; auto& src_iter_tensor = external_function->get_tensor_data(args[1].get_name());
auto& weights_layer_tensor = tensor_data[args[2].get_name()]; auto& weights_layer_tensor = external_function->get_tensor_data(args[2].get_name());
auto& weights_iter_tensor = tensor_data[args[3].get_name()]; auto& weights_iter_tensor = external_function->get_tensor_data(args[3].get_name());
auto& bias_tensor = tensor_data[args[4].get_name()]; auto& bias_tensor = external_function->get_tensor_data(args[4].get_name());
auto& dst_layer_tensor = tensor_data[out[0].get_name()]; auto& dst_layer_tensor = external_function->get_tensor_data(out[0].get_name());
auto& dst_iter_tensor = tensor_data[out[1].get_name()]; auto& dst_iter_tensor = external_function->get_tensor_data(out[1].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto lstm_index = mkldnn_emitter->build_rnn<ngraph::op::Lstm>(node, args, out); auto lstm_index = mkldnn_emitter->build_rnn<ngraph::op::Lstm>(node, args, out);
......
...@@ -32,11 +32,10 @@ namespace ngraph ...@@ -32,11 +32,10 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::MatmulBias) void Builder::BUILDER_DECL(ngraph::op::MatmulBias)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out0_tensor = tensor_data[out[0].get_name()]; auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
const ngraph::op::MatmulBias* mm = static_cast<const ngraph::op::MatmulBias*>(node); const ngraph::op::MatmulBias* mm = static_cast<const ngraph::op::MatmulBias*>(node);
...@@ -91,7 +90,7 @@ namespace ngraph ...@@ -91,7 +90,7 @@ namespace ngraph
if (args.size() > 2) if (args.size() > 2)
{ {
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto axes = mm->get_broadcast_axes(); auto axes = mm->get_broadcast_axes();
if (axes.size() == 1) if (axes.size() == 1)
...@@ -316,11 +315,10 @@ namespace ngraph ...@@ -316,11 +315,10 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::BatchDot) void Builder::BUILDER_DECL(ngraph::op::BatchDot)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& mat_a = tensor_data[args[0].get_name()]; auto& mat_a = external_function->get_tensor_data(args[0].get_name());
auto& mat_b = tensor_data[args[1].get_name()]; auto& mat_b = external_function->get_tensor_data(args[1].get_name());
auto& mat_c = tensor_data[out[0].get_name()]; auto& mat_c = external_function->get_tensor_data(out[0].get_name());
const auto* cg = static_cast<const ngraph::op::BatchDot*>(node); const auto* cg = static_cast<const ngraph::op::BatchDot*>(node);
......
...@@ -36,13 +36,12 @@ namespace ngraph ...@@ -36,13 +36,12 @@ namespace ngraph
auto max_pool = static_cast<const ngraph::op::MaxPool*>(node); auto max_pool = static_cast<const ngraph::op::MaxPool*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto window_shape = max_pool->get_window_shape(); auto window_shape = max_pool->get_window_shape();
auto window_movement_strides = max_pool->get_window_movement_strides(); auto window_movement_strides = max_pool->get_window_movement_strides();
...@@ -106,15 +105,14 @@ namespace ngraph ...@@ -106,15 +105,14 @@ namespace ngraph
auto mpb = static_cast<const ngraph::op::MaxPoolBackprop*>(node); auto mpb = static_cast<const ngraph::op::MaxPoolBackprop*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto arg_fwd_shape = args[0].get_shape(); auto arg_fwd_shape = args[0].get_shape();
auto delta_shape = args[1].get_shape(); auto delta_shape = args[1].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto& arg_fwd_tensor = tensor_data[args[0].get_name()]; auto& arg_fwd_tensor = external_function->get_tensor_data(args[0].get_name());
auto& delta_tensor = tensor_data[args[1].get_name()]; auto& delta_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto window_shape = mpb->get_window_shape(); auto window_shape = mpb->get_window_shape();
auto window_movement_strides = mpb->get_window_movement_strides(); auto window_movement_strides = mpb->get_window_movement_strides();
...@@ -198,11 +196,10 @@ namespace ngraph ...@@ -198,11 +196,10 @@ namespace ngraph
auto max_pool = static_cast<const ngraph::op::MaxPoolWithIndices*>(node); auto max_pool = static_cast<const ngraph::op::MaxPoolWithIndices*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out0_tensor = tensor_data[out[0].get_name()]; auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
auto& out1_tensor = tensor_data[out[1].get_name()]; auto& out1_tensor = external_function->get_tensor_data(out[1].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = runtime::cpu::mkldnn_utils::get_input_mkldnn_md(node, 0); auto input_desc = runtime::cpu::mkldnn_utils::get_input_mkldnn_md(node, 0);
...@@ -237,11 +234,10 @@ namespace ngraph ...@@ -237,11 +234,10 @@ namespace ngraph
} }
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto mpb = static_cast<const ngraph::op::MaxPoolWithIndicesBackprop*>(node); auto mpb = static_cast<const ngraph::op::MaxPoolWithIndicesBackprop*>(node);
......
...@@ -34,11 +34,10 @@ namespace ngraph ...@@ -34,11 +34,10 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::Pad) void Builder::BUILDER_DECL(ngraph::op::Pad)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& padding_value = tensor_data[args[1].get_name()]; auto& padding_value = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto pad = static_cast<const ngraph::op::Pad*>(node); auto pad = static_cast<const ngraph::op::Pad*>(node);
......
...@@ -37,7 +37,6 @@ namespace ngraph ...@@ -37,7 +37,6 @@ namespace ngraph
auto function = reduce->get_functions()[0]; auto function = reduce->get_functions()[0];
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& callees = external_function->get_callees(); auto& callees = external_function->get_callees();
if (!callees.count(function->get_name())) if (!callees.count(function->get_name()))
...@@ -46,9 +45,9 @@ namespace ngraph ...@@ -46,9 +45,9 @@ namespace ngraph
} }
auto& reducer_external_function = callees[function->get_name()]; auto& reducer_external_function = callees[function->get_name()];
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
......
...@@ -36,7 +36,6 @@ namespace ngraph ...@@ -36,7 +36,6 @@ namespace ngraph
auto function = reduce_window->get_functions()[0]; auto function = reduce_window->get_functions()[0];
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& callees = external_function->get_callees(); auto& callees = external_function->get_callees();
if (!callees.count(function->get_name())) if (!callees.count(function->get_name()))
...@@ -45,9 +44,9 @@ namespace ngraph ...@@ -45,9 +44,9 @@ namespace ngraph
} }
auto& reducer_external_function = callees[function->get_name()]; auto& reducer_external_function = callees[function->get_name()];
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
......
...@@ -16,10 +16,9 @@ ...@@ -16,10 +16,9 @@
#define BUILD_REDUCTION_FUNCTOR(OP, K) \ #define BUILD_REDUCTION_FUNCTOR(OP, K) \
auto& functors = external_function->get_functors(); \ auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
\ \
auto& arg_tensor = tensor_data[args[0].get_name()]; \ auto& arg_tensor = external_function->get_tensor_data(args[0].get_name()); \
auto& out_tensor = tensor_data[out[0].get_name()]; \ auto& out_tensor = external_function->get_tensor_data(out[0].get_name()); \
\ \
auto op = static_cast<const ngraph::op::OP*>(node); \ auto op = static_cast<const ngraph::op::OP*>(node); \
\ \
......
...@@ -33,11 +33,10 @@ namespace ngraph ...@@ -33,11 +33,10 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::ReluBackprop) void Builder::BUILDER_DECL(ngraph::op::ReluBackprop)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_fwd_tensor = tensor_data[args[0].get_name()]; auto& arg_fwd_tensor = external_function->get_tensor_data(args[0].get_name());
auto& delta_tensor = tensor_data[args[1].get_name()]; auto& delta_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
size_t count = out[0].get_size(); size_t count = out[0].get_size();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
......
...@@ -33,12 +33,11 @@ namespace ngraph ...@@ -33,12 +33,11 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::ReplaceSlice) void Builder::BUILDER_DECL(ngraph::op::ReplaceSlice)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto replace_slice = static_cast<const ngraph::op::ReplaceSlice*>(node); auto replace_slice = static_cast<const ngraph::op::ReplaceSlice*>(node);
......
...@@ -35,10 +35,9 @@ namespace ngraph ...@@ -35,10 +35,9 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::Reshape) void Builder::BUILDER_DECL(ngraph::op::Reshape)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto reshape = static_cast<const ngraph::op::Reshape*>(node); auto reshape = static_cast<const ngraph::op::Reshape*>(node);
......
...@@ -33,10 +33,9 @@ namespace ngraph ...@@ -33,10 +33,9 @@ namespace ngraph
auto reverse = static_cast<const ngraph::op::Reverse*>(node); auto reverse = static_cast<const ngraph::op::Reverse*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto arg_shape = args[0].get_shape(); auto arg_shape = args[0].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
......
...@@ -33,11 +33,10 @@ namespace ngraph ...@@ -33,11 +33,10 @@ namespace ngraph
auto rev_seq = static_cast<const ngraph::op::ReverseSequence*>(node); auto rev_seq = static_cast<const ngraph::op::ReverseSequence*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& seq_len_tensor = tensor_data[args[1].get_name()]; auto& seq_len_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto arg_shape = args[0].get_shape(); auto arg_shape = args[0].get_shape();
......
...@@ -39,15 +39,14 @@ namespace ngraph ...@@ -39,15 +39,14 @@ namespace ngraph
} }
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& src_layer_tensor = tensor_data[args[0].get_name()]; auto& src_layer_tensor = external_function->get_tensor_data(args[0].get_name());
auto& src_iter_tensor = tensor_data[args[1].get_name()]; auto& src_iter_tensor = external_function->get_tensor_data(args[1].get_name());
auto& weights_layer_tensor = tensor_data[args[2].get_name()]; auto& weights_layer_tensor = external_function->get_tensor_data(args[2].get_name());
auto& weights_iter_tensor = tensor_data[args[3].get_name()]; auto& weights_iter_tensor = external_function->get_tensor_data(args[3].get_name());
auto& bias_tensor = tensor_data[args[4].get_name()]; auto& bias_tensor = external_function->get_tensor_data(args[4].get_name());
auto& dst_layer_tensor = tensor_data[out[0].get_name()]; auto& dst_layer_tensor = external_function->get_tensor_data(out[0].get_name());
auto& dst_iter_tensor = tensor_data[out[1].get_name()]; auto& dst_iter_tensor = external_function->get_tensor_data(out[1].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto rnn_index = mkldnn_emitter->build_rnn<ngraph::op::Rnn>(node, args, out); auto rnn_index = mkldnn_emitter->build_rnn<ngraph::op::Rnn>(node, args, out);
......
...@@ -31,13 +31,12 @@ namespace ngraph ...@@ -31,13 +31,12 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::Select) void Builder::BUILDER_DECL(ngraph::op::Select)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto element_count = args[0].get_size(); auto element_count = args[0].get_size();
......
...@@ -39,7 +39,6 @@ namespace ngraph ...@@ -39,7 +39,6 @@ namespace ngraph
auto backend = runtime::Backend::create("CPU"); auto backend = runtime::Backend::create("CPU");
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& callees = external_function->get_callees(); auto& callees = external_function->get_callees();
// Note: We bypass the completely broken ngraph "backend" API here // Note: We bypass the completely broken ngraph "backend" API here
...@@ -53,13 +52,13 @@ namespace ngraph ...@@ -53,13 +52,13 @@ namespace ngraph
} }
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto arg1_shape = args[1].get_shape(); auto arg1_shape = args[1].get_shape();
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto window_shape = select_and_scatter->get_window_shape(); auto window_shape = select_and_scatter->get_window_shape();
auto window_movement_strides = select_and_scatter->get_window_movement_strides(); auto window_movement_strides = select_and_scatter->get_window_movement_strides();
......
...@@ -33,10 +33,9 @@ namespace ngraph ...@@ -33,10 +33,9 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::Sigmoid) void Builder::BUILDER_DECL(ngraph::op::Sigmoid)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto input_shape = args[0].get_shape(); auto input_shape = args[0].get_shape();
auto out_shape = out[0].get_shape(); auto out_shape = out[0].get_shape();
...@@ -69,11 +68,10 @@ namespace ngraph ...@@ -69,11 +68,10 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::SigmoidBackprop) void Builder::BUILDER_DECL(ngraph::op::SigmoidBackprop)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto input_shape = args[0].get_shape(); auto input_shape = args[0].get_shape();
auto delta_shape = args[1].get_shape(); auto delta_shape = args[1].get_shape();
......
...@@ -35,10 +35,9 @@ namespace ngraph ...@@ -35,10 +35,9 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::Slice) void Builder::BUILDER_DECL(ngraph::op::Slice)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
const ngraph::op::Slice* slice = static_cast<const ngraph::op::Slice*>(node); const ngraph::op::Slice* slice = static_cast<const ngraph::op::Slice*>(node);
......
...@@ -35,12 +35,11 @@ namespace ngraph ...@@ -35,12 +35,11 @@ namespace ngraph
auto softmax = static_cast<const ngraph::op::Softmax*>(node); auto softmax = static_cast<const ngraph::op::Softmax*>(node);
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto arg_shape = args[0].get_shape(); auto arg_shape = args[0].get_shape();
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto axes = softmax->get_axes(); auto axes = softmax->get_axes();
......
...@@ -175,12 +175,11 @@ namespace ngraph ...@@ -175,12 +175,11 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::And) void Builder::BUILDER_DECL(ngraph::op::And)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto element_count = out[0].get_size(); auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out0_tensor = tensor_data[out[0].get_name()]; auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
auto functor = [&, element_count](CPURuntimeContext* ctx) { auto functor = [&, element_count](CPURuntimeContext* ctx) {
runtime::cpu::kernel::logical_and( runtime::cpu::kernel::logical_and(
...@@ -193,12 +192,11 @@ namespace ngraph ...@@ -193,12 +192,11 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::Or) void Builder::BUILDER_DECL(ngraph::op::Or)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto element_count = out[0].get_size(); auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out0_tensor = tensor_data[out[0].get_name()]; auto& out0_tensor = external_function->get_tensor_data(out[0].get_name());
auto functor = [&, element_count](CPURuntimeContext* ctx) { auto functor = [&, element_count](CPURuntimeContext* ctx) {
runtime::cpu::kernel::logical_or( runtime::cpu::kernel::logical_or(
...@@ -348,17 +346,18 @@ namespace ngraph ...@@ -348,17 +346,18 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::Constant) void Builder::BUILDER_DECL(ngraph::op::Constant)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
vector<void**> dest; vector<void**> dest;
for (auto& result : external_function->get_function()->get_results()) for (auto& result : external_function->get_function()->get_results())
{ {
if (result.get() == node) if (result.get() == node)
{ {
dest.push_back(&tensor_data[result->get_output_tensor(0).get_name()]); dest.push_back(&external_function->get_tensor_data(
result->get_output_tensor(0).get_name()));
} }
} }
auto& src = tensor_data[node->get_output_tensor(0).get_name()]; auto& src =
external_function->get_tensor_data(node->get_output_tensor(0).get_name());
auto size = node->get_output_tensor(0).size(); auto size = node->get_output_tensor(0).size();
auto functor = [&, dest, src, size](CPURuntimeContext* ctx) { auto functor = [&, dest, src, size](CPURuntimeContext* ctx) {
for (auto p : dest) for (auto p : dest)
......
...@@ -204,14 +204,13 @@ ...@@ -204,14 +204,13 @@
#define BUILD_UNARY_ELEMWISE_FUNCTOR(OP) \ #define BUILD_UNARY_ELEMWISE_FUNCTOR(OP) \
auto& functors = external_function->get_functors(); \ auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
std::function<void(void*, void*, size_t)> kernel; \ std::function<void(void*, void*, size_t)> kernel; \
\ \
SELECT_KERNEL(kernel, args[0].get_element_type(), OP); \ SELECT_KERNEL(kernel, args[0].get_element_type(), OP); \
\ \
auto element_count = out[0].get_size(); \ auto element_count = out[0].get_size(); \
auto& arg0_tensor = tensor_data[args[0].get_name()]; \ auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); \
auto& out0_tensor = tensor_data[out[0].get_name()]; \ auto& out0_tensor = external_function->get_tensor_data(out[0].get_name()); \
\ \
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \ auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \
kernel(arg0_tensor, out0_tensor, element_count); \ kernel(arg0_tensor, out0_tensor, element_count); \
...@@ -220,15 +219,14 @@ ...@@ -220,15 +219,14 @@
#define BUILD_BINARY_ELEMWISE_FUNCTOR(OP) \ #define BUILD_BINARY_ELEMWISE_FUNCTOR(OP) \
auto& functors = external_function->get_functors(); \ auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
std::function<void(void*, void*, void*, size_t)> kernel; \ std::function<void(void*, void*, void*, size_t)> kernel; \
\ \
SELECT_KERNEL(kernel, args[0].get_element_type(), OP); \ SELECT_KERNEL(kernel, args[0].get_element_type(), OP); \
\ \
auto element_count = out[0].get_size(); \ auto element_count = out[0].get_size(); \
auto& arg0_tensor = tensor_data[args[0].get_name()]; \ auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name()); \
auto& arg1_tensor = tensor_data[args[1].get_name()]; \ auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name()); \
auto& out0_tensor = tensor_data[out[0].get_name()]; \ auto& out0_tensor = external_function->get_tensor_data(out[0].get_name()); \
\ \
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \ auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \
kernel(arg0_tensor, arg1_tensor, out0_tensor, element_count); \ kernel(arg0_tensor, arg1_tensor, out0_tensor, element_count); \
......
...@@ -654,7 +654,7 @@ using namespace ngraph::runtime; ...@@ -654,7 +654,7 @@ using namespace ngraph::runtime;
ss << "((" << type << "*)(inputs[" << arg_index << "]))"; ss << "((" << type << "*)(inputs[" << arg_index << "]))";
m_variable_name_map[tv->get_tensor().get_name()] = ss.str(); m_variable_name_map[tv->get_tensor().get_name()] = ss.str();
param_index_map[tv->get_tensor().get_name()] = arg_index; param_index_map[tv->get_tensor().get_name()] = arg_index;
propagate_in_place_input(&param->get_outputs().at(i), ss.str()); propagate_in_place_input(&param->get_outputs().at(i), ss.str(), false);
arg_index++; arg_index++;
} }
} }
...@@ -679,7 +679,8 @@ using namespace ngraph::runtime; ...@@ -679,7 +679,8 @@ using namespace ngraph::runtime;
auto output_name = ss.str(); auto output_name = ss.str();
m_variable_name_map[itv->get_tensor().get_name()] = ss.str(); m_variable_name_map[itv->get_tensor().get_name()] = ss.str();
propagate_in_place_output(&(res->get_inputs().at(0).get_output()), output_name); propagate_in_place_output(
&(res->get_inputs().at(0).get_output()), output_name, false);
} }
} }
...@@ -973,7 +974,7 @@ using namespace ngraph::runtime; ...@@ -973,7 +974,7 @@ using namespace ngraph::runtime;
} }
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input( void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input(
ngraph::descriptor::Output* output, std::string input_name) ngraph::descriptor::Output* output, std::string input_name, bool dex)
{ {
std::deque<ngraph::descriptor::Output*> stack; std::deque<ngraph::descriptor::Output*> stack;
stack.push_front(output); stack.push_front(output);
...@@ -999,7 +1000,15 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input( ...@@ -999,7 +1000,15 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input(
size_t output_index = oi_pair.output; size_t output_index = oi_pair.output;
auto& output_tensor = c_op->get_outputs().at(output_index).get_tensor(); auto& output_tensor = c_op->get_outputs().at(output_index).get_tensor();
m_variable_name_map[output_tensor.get_name()] = input_name; if (dex)
{
tensor_alias[output_tensor.get_name()] = input_name;
}
else
{
m_variable_name_map[output_tensor.get_name()] = input_name;
}
NGRAPH_DEBUG << "CPU codegen: Forwarding " << input_name << " through " NGRAPH_DEBUG << "CPU codegen: Forwarding " << input_name << " through "
<< output_tensor.get_name(); << output_tensor.get_name();
stack.push_back(&c_op->get_outputs().at(output_index)); stack.push_back(&c_op->get_outputs().at(output_index));
...@@ -1011,7 +1020,7 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input( ...@@ -1011,7 +1020,7 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input(
} }
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output( void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
ngraph::descriptor::Output* res_src_output, std::string output_name) ngraph::descriptor::Output* res_src_output, std::string output_name, bool dex)
{ {
//we start with a particular output //we start with a particular output
//which is an argument to a given op::Result //which is an argument to a given op::Result
...@@ -1041,7 +1050,16 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output( ...@@ -1041,7 +1050,16 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
{ {
NGRAPH_DEBUG << "Reusing " << output_name << " for " NGRAPH_DEBUG << "Reusing " << output_name << " for "
<< input_tensor.get_name(); << input_tensor.get_name();
m_variable_name_map[input_tensor.get_name()] = output_name;
if (dex)
{
tensor_alias[input_tensor.get_name()] = output_name;
}
else
{
m_variable_name_map[input_tensor.get_name()] = output_name;
}
it = &arg->get_inputs().at(input_index).get_output(); it = &arg->get_inputs().at(input_index).get_output();
propagate_further = true; propagate_further = true;
} }
...@@ -1133,6 +1151,8 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1133,6 +1151,8 @@ void runtime::cpu::CPU_ExternalFunction::build()
function_input_index.emplace_back(tensor_data[tv->get_tensor().get_name()], function_input_index.emplace_back(tensor_data[tv->get_tensor().get_name()],
arg_index, arg_index,
tensor_stale[tv->get_tensor().get_name()]); tensor_stale[tv->get_tensor().get_name()]);
propagate_in_place_input(
&param->get_outputs().at(i), tv->get_tensor().get_name(), true);
arg_index++; arg_index++;
} }
} }
...@@ -1150,6 +1170,8 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1150,6 +1170,8 @@ void runtime::cpu::CPU_ExternalFunction::build()
shared_ptr<descriptor::TensorView> itv = shared_ptr<descriptor::TensorView> itv =
res->get_inputs().at(0).get_output().get_tensor_view(); res->get_inputs().at(0).get_output().get_tensor_view();
function_output_index.emplace_back(tensor_data[itv->get_tensor().get_name()], i); function_output_index.emplace_back(tensor_data[itv->get_tensor().get_name()], i);
propagate_in_place_output(
&(res->get_inputs().at(0).get_output()), tv->get_tensor().get_name(), true);
} }
} }
...@@ -1437,6 +1459,18 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1437,6 +1459,18 @@ void runtime::cpu::CPU_ExternalFunction::build()
} }
} }
void*& runtime::cpu::CPU_ExternalFunction::get_tensor_data(const std::string& name)
{
if (tensor_alias.count(name))
{
return tensor_data[tensor_alias[name]];
}
else
{
return tensor_data[name];
}
}
shared_ptr<ngraph::runtime::cpu::CPU_CallFrame> shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
runtime::cpu::CPU_ExternalFunction::make_call_frame() runtime::cpu::CPU_ExternalFunction::make_call_frame()
{ {
......
...@@ -101,6 +101,7 @@ namespace ngraph ...@@ -101,6 +101,7 @@ namespace ngraph
return functors; return functors;
} }
std::unordered_map<std::string, void*>& get_tensor_data() { return tensor_data; } std::unordered_map<std::string, void*>& get_tensor_data() { return tensor_data; }
void*& get_tensor_data(const std::string& name);
std::function<void(CPURuntimeContext*, std::vector<void*>&, std::vector<void*>&)>& std::function<void(CPURuntimeContext*, std::vector<void*>&, std::vector<void*>&)>&
get_executor() get_executor()
{ {
...@@ -120,11 +121,13 @@ namespace ngraph ...@@ -120,11 +121,13 @@ namespace ngraph
// For non-destructive passthrough kernels, propagate function // For non-destructive passthrough kernels, propagate function
// input buffers to internal ops // input buffers to internal ops
void propagate_in_place_input(ngraph::descriptor::Output* output, void propagate_in_place_input(ngraph::descriptor::Output* output,
std::string input_name); std::string input_name,
bool dex);
// For in-place kernels, propagate function output buffers to // For in-place kernels, propagate function output buffers to
// internal ops // internal ops
void propagate_in_place_output(ngraph::descriptor::Output* res_src_output, void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
std::string output_name); std::string output_name,
bool dex);
void emit_debug_function_entry(codegen::CodeWriter& writer, void emit_debug_function_entry(codegen::CodeWriter& writer,
Node* node, Node* node,
const std::vector<TensorViewWrapper>& in, const std::vector<TensorViewWrapper>& in,
...@@ -179,6 +182,7 @@ namespace ngraph ...@@ -179,6 +182,7 @@ namespace ngraph
executor; executor;
std::unordered_map<std::string, void*> tensor_data; std::unordered_map<std::string, void*> tensor_data;
std::unordered_map<std::string, bool> tensor_stale; std::unordered_map<std::string, bool> tensor_stale;
std::unordered_map<std::string, std::string> tensor_alias;
std::list<std::pair<std::reference_wrapper<void*>, size_t>> intermediates_offsets; std::list<std::pair<std::reference_wrapper<void*>, size_t>> intermediates_offsets;
std::list< std::list<
std::tuple<std::reference_wrapper<void*>, size_t, std::reference_wrapper<bool>>> std::tuple<std::reference_wrapper<void*>, size_t, std::reference_wrapper<bool>>>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment