Commit 6e234d65 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Ayzhuang/propagate cacheability (#1982)

* Add cacheablility propagation pass.

* Use a functor to create op annotations.

* Address PR feedback.

* Address PR feedback.

* Address PR feedback.
parent f0c17477
...@@ -140,6 +140,7 @@ set (SRC ...@@ -140,6 +140,7 @@ set (SRC
pass/memory_visualize.cpp pass/memory_visualize.cpp
pass/nop_elimination.cpp pass/nop_elimination.cpp
pass/pass.cpp pass/pass.cpp
pass/propagate_cacheability.cpp
pass/reshape_elimination.cpp pass/reshape_elimination.cpp
pass/zero_dim_tensor_elimination.cpp pass/zero_dim_tensor_elimination.cpp
pass/validate_graph.cpp pass/validate_graph.cpp
......
...@@ -54,9 +54,13 @@ namespace ngraph ...@@ -54,9 +54,13 @@ namespace ngraph
return m_in_place_oi_pairs; return m_in_place_oi_pairs;
} }
bool is_cacheable() const { return m_cacheable; }
void set_cacheable(bool val) { m_cacheable = val; }
private: private:
// map of output-input pairs for which in-place computation is valid // map of output-input pairs for which in-place computation is valid
std::vector<struct oi_pair> m_in_place_oi_pairs; std::vector<struct oi_pair> m_in_place_oi_pairs;
bool m_cacheable = false;
}; };
} }
} }
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/propagate_cacheability.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
using namespace ngraph;
bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Function> function)
{
for (auto& node : function->get_ordered_ops())
{
if (auto op = std::dynamic_pointer_cast<op::Op>(node))
{
NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name();
auto op_annotations = op->get_op_annotations();
if (!op_annotations)
{
NGRAPH_DEBUG << "propagate cacheability: create op_annotations";
op_annotations = op_annotations_factory();
op->set_op_annotations(op_annotations);
}
if (std::dynamic_pointer_cast<op::Constant>(node))
{
op_annotations->set_cacheable(true);
NGRAPH_DEBUG << "propagate cacheability: cacheability is 1";
}
else if (auto parameter = std::dynamic_pointer_cast<op::Parameter>(node))
{
op_annotations->set_cacheable(parameter->get_cacheable());
NGRAPH_DEBUG << "propagate cacheability: cacheability is "
<< parameter->get_cacheable();
}
else
{
bool cacheable = true;
for (auto arg : node->get_arguments())
{
NGRAPH_DEBUG << "propagate cacheability: arg is " << arg->get_name();
if (auto arg_op = std::dynamic_pointer_cast<op::Op>(arg))
{
auto arg_op_annotations = arg_op->get_op_annotations();
NGRAPH_ASSERT(arg_op_annotations);
if (!arg_op_annotations->is_cacheable())
{
cacheable = false;
break;
}
}
}
NGRAPH_DEBUG << "propagate cacheability: cacheability is " << cacheable;
op_annotations->set_cacheable(cacheable);
}
}
}
return false;
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class PropagateCacheability;
}
}
class ngraph::pass::PropagateCacheability : public FunctionPass
{
public:
PropagateCacheability()
: FunctionPass()
{
}
PropagateCacheability(
std::function<std::shared_ptr<ngraph::op::util::OpAnnotations>(void)> func)
: FunctionPass()
, op_annotations_factory(func)
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
private:
std::function<std::shared_ptr<ngraph::op::util::OpAnnotations>(void)> op_annotations_factory =
[]() -> std::shared_ptr<ngraph::op::util::OpAnnotations> {
auto op_annotations = std::make_shared<ngraph::op::util::OpAnnotations>();
return op_annotations;
};
};
...@@ -24,6 +24,7 @@ set(SRC ...@@ -24,6 +24,7 @@ set(SRC
cpu_external_function.cpp cpu_external_function.cpp
cpu_kernels.cpp cpu_kernels.cpp
cpu_layout_descriptor.cpp cpu_layout_descriptor.cpp
cpu_op_annotations.cpp
cpu_tensor_view_wrapper.cpp cpu_tensor_view_wrapper.cpp
cpu_tensor_view.cpp cpu_tensor_view.cpp
cpu_tracing.cpp cpu_tracing.cpp
......
...@@ -128,6 +128,7 @@ ...@@ -128,6 +128,7 @@
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/nop_elimination.hpp" #include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/propagate_cacheability.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp" #include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/aligned_buffer.hpp" #include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp" #include "ngraph/runtime/cpu/cpu_backend.hpp"
...@@ -137,6 +138,7 @@ ...@@ -137,6 +138,7 @@
#include "ngraph/runtime/cpu/cpu_emitter.hpp" #include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp" #include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp" #include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
#include "ngraph/runtime/cpu/cpu_tracing.hpp" #include "ngraph/runtime/cpu/cpu_tracing.hpp"
#include "ngraph/runtime/cpu/cpu_visualize_tree.hpp" #include "ngraph/runtime/cpu/cpu_visualize_tree.hpp"
...@@ -409,6 +411,8 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -409,6 +411,8 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::CommonFunctionCollection>( pass_manager.register_pass<ngraph::pass::CommonFunctionCollection>(
femitter, node_function_map, common_function_string); femitter, node_function_map, common_function_string);
pass_manager.register_pass<ngraph::pass::Liveness>(); pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::PropagateCacheability>(
runtime::cpu::get_annotations_factory());
pass_manager.register_pass<ngraph::pass::MemoryLayout>(size_t(s_memory_pool_alignment), true); pass_manager.register_pass<ngraph::pass::MemoryLayout>(size_t(s_memory_pool_alignment), true);
pass_manager.run_passes(m_function); pass_manager.run_passes(m_function);
...@@ -1380,6 +1384,8 @@ void runtime::cpu::CPU_ExternalFunction::build() ...@@ -1380,6 +1384,8 @@ void runtime::cpu::CPU_ExternalFunction::build()
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
register_common_passes(pass_manager); register_common_passes(pass_manager);
pass_manager.register_pass<ngraph::pass::Liveness>(); pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::PropagateCacheability>(
runtime::cpu::get_annotations_factory());
pass_manager.register_pass<ngraph::pass::MemoryLayout>(size_t(s_memory_pool_alignment), true); pass_manager.register_pass<ngraph::pass::MemoryLayout>(size_t(s_memory_pool_alignment), true);
pass_manager.run_passes(m_function, false); pass_manager.run_passes(m_function, false);
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/op/util/op_annotations.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
std::function<std::shared_ptr<ngraph::op::util::OpAnnotations>(void)>
get_annotations_factory()
{
std::function<std::shared_ptr<ngraph::op::util::OpAnnotations>(void)> func =
[]() -> std::shared_ptr<ngraph::op::util::OpAnnotations> {
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
return op_annotations;
};
return func;
}
}
}
}
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
#pragma once #pragma once
#include <functional>
#include <memory>
#include "ngraph/op/util/op_annotations.hpp" #include "ngraph/op/util/op_annotations.hpp"
namespace ngraph namespace ngraph
...@@ -34,6 +37,9 @@ namespace ngraph ...@@ -34,6 +37,9 @@ namespace ngraph
private: private:
bool m_mkldnn_op = false; bool m_mkldnn_op = false;
}; };
std::function<std::shared_ptr<ngraph::op::util::OpAnnotations>(void)>
get_annotations_factory();
} }
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment