Commit 718e2ef1 authored by Amy Zhuang's avatar Amy Zhuang Committed by Sang Ik Lee

Do not allow builder to access tensor_data map directly. (#2494)

parent f1c72364
...@@ -33,13 +33,12 @@ namespace ngraph ...@@ -33,13 +33,12 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::ArgMax) void Builder::BUILDER_DECL(ngraph::op::ArgMax)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
const ngraph::op::ArgMax* argmax = static_cast<const ngraph::op::ArgMax*>(node); const ngraph::op::ArgMax* argmax = static_cast<const ngraph::op::ArgMax*>(node);
CPUKernelFunctor functor; CPUKernelFunctor functor;
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
if (out[0].get_element_type() != element::i64 && if (out[0].get_element_type() != element::i64 &&
out[0].get_element_type() != element::i32) out[0].get_element_type() != element::i32)
{ {
......
...@@ -33,13 +33,12 @@ namespace ngraph ...@@ -33,13 +33,12 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::ArgMin) void Builder::BUILDER_DECL(ngraph::op::ArgMin)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
const ngraph::op::ArgMin* argmin = static_cast<const ngraph::op::ArgMin*>(node); const ngraph::op::ArgMin* argmin = static_cast<const ngraph::op::ArgMin*>(node);
CPUKernelFunctor functor; CPUKernelFunctor functor;
auto& arg_tensor = tensor_data[args[0].get_name()]; auto& arg_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
if (out[0].get_element_type() != element::i64 && if (out[0].get_element_type() != element::i64 &&
out[0].get_element_type() != element::i32) out[0].get_element_type() != element::i32)
{ {
......
...@@ -33,13 +33,12 @@ namespace ngraph ...@@ -33,13 +33,12 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::EmbeddingLookup) void Builder::BUILDER_DECL(ngraph::op::EmbeddingLookup)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
CPUKernelFunctor functor; CPUKernelFunctor functor;
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& arg1_tensor = tensor_data[args[1].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
if (out[0].get_element_type() != element::f32 && if (out[0].get_element_type() != element::f32 &&
out[0].get_element_type() != element::f64) out[0].get_element_type() != element::f64)
{ {
......
...@@ -39,7 +39,6 @@ namespace ngraph ...@@ -39,7 +39,6 @@ namespace ngraph
void Builder::BUILDER_DECL(ngraph::op::Dequantize) void Builder::BUILDER_DECL(ngraph::op::Dequantize)
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
const ngraph::op::Dequantize* dequantize = const ngraph::op::Dequantize* dequantize =
static_cast<const ngraph::op::Dequantize*>(node); static_cast<const ngraph::op::Dequantize*>(node);
...@@ -47,8 +46,9 @@ namespace ngraph ...@@ -47,8 +46,9 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0); auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0); auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
...@@ -106,10 +106,11 @@ namespace ngraph ...@@ -106,10 +106,11 @@ namespace ngraph
} }
else else
{ {
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape(); auto arg1_shape = args[1].get_shape();
auto daxes = dequantize->get_axes(); auto daxes = dequantize->get_axes();
...@@ -306,16 +307,15 @@ namespace ngraph ...@@ -306,16 +307,15 @@ namespace ngraph
else else
{ {
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
const ngraph::op::Quantize* quantize = const ngraph::op::Quantize* quantize =
static_cast<const ngraph::op::Quantize*>(node); static_cast<const ngraph::op::Quantize*>(node);
CPUKernelFunctor functor; CPUKernelFunctor functor;
auto& arg0_tensor = tensor_data[args[0].get_name()]; auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& arg1_tensor = tensor_data[args[1].get_name()]; auto& arg1_tensor = external_function->get_tensor_data(args[1].get_name());
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = external_function->get_tensor_data(args[2].get_name());
auto& out_tensor = tensor_data[out[0].get_name()]; auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape(); auto arg1_shape = args[1].get_shape();
......
...@@ -132,7 +132,6 @@ namespace ngraph ...@@ -132,7 +132,6 @@ namespace ngraph
static constexpr size_t s_memory_pool_alignment = 4096; static constexpr size_t s_memory_pool_alignment = 4096;
std::vector<CPUKernelFunctor>& get_functors() { return functors; } std::vector<CPUKernelFunctor>& get_functors() { return functors; }
std::unordered_map<std::string, void*>& get_tensor_data() { return tensor_data; }
void*& get_tensor_data(const std::string& name); void*& get_tensor_data(const std::string& name);
std::function<void(CPURuntimeContext*, std::vector<void*>&, std::vector<void*>&)>& std::function<void(CPURuntimeContext*, std::vector<void*>&, std::vector<void*>&)>&
get_executor() get_executor()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment