Commit c9eef901 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

is_op (#2203)

parent 90aa7336
...@@ -590,8 +590,9 @@ bool ngraph::possibly_overwritten(Node* node) ...@@ -590,8 +590,9 @@ bool ngraph::possibly_overwritten(Node* node)
{ {
for (const descriptor::Input* input : output.get_inputs()) for (const descriptor::Input* input : output.get_inputs())
{ {
if (auto op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node())) if (input->get_node()->is_op())
{ {
auto op = std::static_pointer_cast<ngraph::op::Op>(input->get_node());
if (auto op_annotations = op->get_op_annotations()) if (auto op_annotations = op->get_op_annotations())
{ {
for (auto oi_pair : op_annotations->get_in_place_oi_pairs()) for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
......
...@@ -140,6 +140,7 @@ namespace ngraph ...@@ -140,6 +140,7 @@ namespace ngraph
bool is_parameter() const; bool is_parameter() const;
virtual bool is_output() const; virtual bool is_output() const;
virtual bool is_constant() const; virtual bool is_constant() const;
virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; } virtual bool is_commutative() { return false; }
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
......
...@@ -38,6 +38,7 @@ namespace ngraph ...@@ -38,6 +38,7 @@ namespace ngraph
return m_op_annotations; return m_op_annotations;
} }
virtual bool is_op() const override { return true; }
protected: protected:
Op(const std::string& node_type, const NodeVector& arguments); Op(const std::string& node_type, const NodeVector& arguments);
......
...@@ -52,8 +52,9 @@ bool pass::CommonFunctionCollection::run_on_module(vector<shared_ptr<Function>>& ...@@ -52,8 +52,9 @@ bool pass::CommonFunctionCollection::run_on_module(vector<shared_ptr<Function>>&
{ {
continue; continue;
} }
if (auto op = std::dynamic_pointer_cast<op::Op>(n)) if (n->is_op())
{ {
auto op = std::static_pointer_cast<op::Op>(n);
auto annotations = op->get_op_annotations(); auto annotations = op->get_op_annotations();
// If an op is passed through, do not add it to the common function // If an op is passed through, do not add it to the common function
// collection so that the emitter can decide to eliminate it if desired // collection so that the emitter can decide to eliminate it if desired
......
...@@ -47,8 +47,9 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function) ...@@ -47,8 +47,9 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
std::map<descriptor::Tensor*, descriptor::Tensor*> in_place_outputs; std::map<descriptor::Tensor*, descriptor::Tensor*> in_place_outputs;
std::set<const descriptor::Tensor*> reused_inputs; std::set<const descriptor::Tensor*> reused_inputs;
if (auto op = std::dynamic_pointer_cast<op::Op>(node)) if (node->is_op())
{ {
auto op = std::static_pointer_cast<op::Op>(node);
// concat and slice in_place_oi should be treated differently // concat and slice in_place_oi should be treated differently
if (!std::dynamic_pointer_cast<op::Concat>(node) && if (!std::dynamic_pointer_cast<op::Concat>(node) &&
!std::dynamic_pointer_cast<op::Slice>(node)) !std::dynamic_pointer_cast<op::Slice>(node))
......
...@@ -28,8 +28,9 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi ...@@ -28,8 +28,9 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi
{ {
for (auto& node : function->get_ordered_ops()) for (auto& node : function->get_ordered_ops())
{ {
if (auto op = std::dynamic_pointer_cast<op::Op>(node)) if (node->is_op())
{ {
auto op = std::static_pointer_cast<op::Op>(node);
NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name(); NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name();
auto op_annotations = op->get_op_annotations(); auto op_annotations = op->get_op_annotations();
if (!op_annotations) if (!op_annotations)
...@@ -55,8 +56,9 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi ...@@ -55,8 +56,9 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
NGRAPH_DEBUG << "propagate cacheability: arg is " << arg->get_name(); NGRAPH_DEBUG << "propagate cacheability: arg is " << arg->get_name();
if (auto arg_op = std::dynamic_pointer_cast<op::Op>(arg)) if (arg->is_op())
{ {
auto arg_op = std::static_pointer_cast<op::Op>(arg);
auto arg_op_annotations = arg_op->get_op_annotations(); auto arg_op_annotations = arg_op->get_op_annotations();
NGRAPH_ASSERT(arg_op_annotations); NGRAPH_ASSERT(arg_op_annotations);
if (!arg_op_annotations->is_cacheable()) if (!arg_op_annotations->is_cacheable())
......
...@@ -1135,12 +1135,14 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input( ...@@ -1135,12 +1135,14 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input(
stack.pop_front(); stack.pop_front();
for (auto input : it->get_inputs()) for (auto input : it->get_inputs())
{ {
auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node()); auto input_node = input->get_node();
if (!c_op || c_op->is_output() || dynamic_pointer_cast<ngraph::op::Slice>(c_op)) if (!input_node->is_op() || input_node->is_output() ||
dynamic_pointer_cast<ngraph::op::Slice>(input_node))
{ {
continue; continue;
} }
auto c_op = std::static_pointer_cast<ngraph::op::Op>(input_node);
if (auto op_annotations = c_op->get_op_annotations()) if (auto op_annotations = c_op->get_op_annotations())
{ {
for (auto oi_pair : op_annotations->get_in_place_oi_pairs()) for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
...@@ -1182,12 +1184,13 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_constant( ...@@ -1182,12 +1184,13 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_constant(
stack.pop_front(); stack.pop_front();
for (auto input : it->get_inputs()) for (auto input : it->get_inputs())
{ {
auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node()); auto input_node = input->get_node();
if (!c_op || c_op->is_output()) if (!input_node->is_op() || input_node->is_output())
{ {
continue; continue;
} }
auto c_op = std::static_pointer_cast<ngraph::op::Op>(input_node);
if (auto op_annotations = c_op->get_op_annotations()) if (auto op_annotations = c_op->get_op_annotations())
{ {
for (auto oi_pair : op_annotations->get_in_place_oi_pairs()) for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
...@@ -1229,11 +1232,14 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output( ...@@ -1229,11 +1232,14 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
do do
{ {
propagate_further = false; propagate_further = false;
auto arg = std::dynamic_pointer_cast<ngraph::op::Op>(it->get_node()); auto it_node = it->get_node();
if (!arg || std::dynamic_pointer_cast<ngraph::op::Slice>(it->get_node()))
if (!it_node->is_op() || std::dynamic_pointer_cast<ngraph::op::Slice>(it_node))
{ {
break; break;
} }
auto arg = std::static_pointer_cast<ngraph::op::Op>(it_node);
if (auto op_annotations = arg->get_op_annotations()) if (auto op_annotations = arg->get_op_annotations())
{ {
for (auto oi_pair : op_annotations->get_in_place_oi_pairs()) for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
...@@ -1284,8 +1290,7 @@ void runtime::cpu::CPU_ExternalFunction::process_in_place_concat( ...@@ -1284,8 +1290,7 @@ void runtime::cpu::CPU_ExternalFunction::process_in_place_concat(
auto offset = output_tensor->get_pool_offset(); auto offset = output_tensor->get_pool_offset();
for (auto arg : concat->get_arguments()) for (auto arg : concat->get_arguments())
{ {
auto input_node = std::dynamic_pointer_cast<ngraph::op::Op>(arg); auto input_tensor = &arg->get_output_tensor();
auto input_tensor = &input_node->get_output_tensor();
auto old_offset = input_tensor->get_pool_offset(); auto old_offset = input_tensor->get_pool_offset();
input_tensor->set_pool_offset(offset); input_tensor->set_pool_offset(offset);
NGRAPH_DEBUG << "cpu_external_function: change offset, old offset is " NGRAPH_DEBUG << "cpu_external_function: change offset, old offset is "
...@@ -1349,8 +1354,7 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_concat( ...@@ -1349,8 +1354,7 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_concat(
auto offset = output_tensor->get_pool_offset(); auto offset = output_tensor->get_pool_offset();
for (auto arg : it->get_arguments()) for (auto arg : it->get_arguments())
{ {
auto input_node = std::dynamic_pointer_cast<ngraph::op::Op>(arg); auto input_tensor = &arg->get_output_tensor();
auto input_tensor = &input_node->get_output_tensor();
auto old_offset = input_tensor->get_pool_offset(); auto old_offset = input_tensor->get_pool_offset();
input_tensor->set_pool_offset(offset); input_tensor->set_pool_offset(offset);
NGRAPH_DEBUG NGRAPH_DEBUG
...@@ -1383,8 +1387,7 @@ void runtime::cpu::CPU_ExternalFunction::process_in_place_slice( ...@@ -1383,8 +1387,7 @@ void runtime::cpu::CPU_ExternalFunction::process_in_place_slice(
auto input = &slice->get_inputs().at(0); auto input = &slice->get_inputs().at(0);
auto arg = input->get_output().get_node(); auto arg = input->get_output().get_node();
auto index = input->get_output().get_index(); auto index = input->get_output().get_index();
auto input_node = std::dynamic_pointer_cast<ngraph::op::Op>(arg); auto input_tensor = &arg->get_output_tensor(index);
auto input_tensor = &input_node->get_output_tensor(index);
if (m_tensor_roles[input_tensor->get_name()] == CPUTensorRole::INPUT) if (m_tensor_roles[input_tensor->get_name()] == CPUTensorRole::INPUT)
{ {
NGRAPH_DEBUG << "cpu_external_function: function input pointer passed to " NGRAPH_DEBUG << "cpu_external_function: function input pointer passed to "
......
...@@ -132,8 +132,9 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr< ...@@ -132,8 +132,9 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
if (!std::dynamic_pointer_cast<op::Concat>(arg)) if (!std::dynamic_pointer_cast<op::Concat>(arg))
{ {
if (auto op = std::dynamic_pointer_cast<op::Op>(arg)) if (arg->is_op())
{ {
auto op = std::static_pointer_cast<op::Op>(arg);
auto annotation = op->get_op_annotations(); auto annotation = op->get_op_annotations();
if (annotation && annotation->get_in_place_oi_pairs().size() > 0) if (annotation && annotation->get_in_place_oi_pairs().size() > 0)
...@@ -174,8 +175,9 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr< ...@@ -174,8 +175,9 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
{ {
if ((user != concat)) if ((user != concat))
{ {
if (auto op = std::dynamic_pointer_cast<op::Op>(user)) if (user->is_op())
{ {
auto op = std::static_pointer_cast<op::Op>(user);
if (auto op_annotations = op->get_op_annotations()) if (auto op_annotations = op->get_op_annotations())
{ {
if (op_annotations->get_in_place_oi_pairs().size() > 0) if (op_annotations->get_in_place_oi_pairs().size() > 0)
......
...@@ -765,12 +765,13 @@ void runtime::gpu::GPU_ExternalFunction::propagate_in_place_input( ...@@ -765,12 +765,13 @@ void runtime::gpu::GPU_ExternalFunction::propagate_in_place_input(
stack.pop_front(); stack.pop_front();
for (auto input : it->get_inputs()) for (auto input : it->get_inputs())
{ {
auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node()); auto input_node = input->get_node();
if (!c_op || c_op->is_output()) if (!input_node->is_op() || input_node->is_output())
{ {
continue; continue;
} }
auto c_op = std::static_pointer_cast<ngraph::op::Op>(input_node);
if (auto op_annotations = c_op->get_op_annotations()) if (auto op_annotations = c_op->get_op_annotations())
{ {
for (auto oi_pair : op_annotations->get_in_place_oi_pairs()) for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
...@@ -804,11 +805,11 @@ void runtime::gpu::GPU_ExternalFunction::propagate_in_place_output( ...@@ -804,11 +805,11 @@ void runtime::gpu::GPU_ExternalFunction::propagate_in_place_output(
do do
{ {
propagate_further = false; propagate_further = false;
auto arg = std::dynamic_pointer_cast<ngraph::op::Op>(it->get_node()); if (!it->get_node()->is_op())
if (!arg)
{ {
break; break;
} }
auto arg = std::static_pointer_cast<ngraph::op::Op>(it->get_node());
if (auto op_annotations = arg->get_op_annotations()) if (auto op_annotations = arg->get_op_annotations())
{ {
for (auto oi_pair : op_annotations->get_in_place_oi_pairs()) for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment