Commit c9eef901 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

is_op (#2203)

parent 90aa7336
......@@ -590,8 +590,9 @@ bool ngraph::possibly_overwritten(Node* node)
{
for (const descriptor::Input* input : output.get_inputs())
{
if (auto op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node()))
if (input->get_node()->is_op())
{
auto op = std::static_pointer_cast<ngraph::op::Op>(input->get_node());
if (auto op_annotations = op->get_op_annotations())
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
......
......@@ -140,6 +140,7 @@ namespace ngraph
bool is_parameter() const;
virtual bool is_output() const;
virtual bool is_constant() const;
virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; }
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
......
......@@ -38,6 +38,7 @@ namespace ngraph
return m_op_annotations;
}
virtual bool is_op() const override { return true; }
protected:
Op(const std::string& node_type, const NodeVector& arguments);
......
......@@ -52,8 +52,9 @@ bool pass::CommonFunctionCollection::run_on_module(vector<shared_ptr<Function>>&
{
continue;
}
if (auto op = std::dynamic_pointer_cast<op::Op>(n))
if (n->is_op())
{
auto op = std::static_pointer_cast<op::Op>(n);
auto annotations = op->get_op_annotations();
// If an op is passed through, do not add it to the common function
// collection so that the emitter can decide to eliminate it if desired
......
......@@ -47,8 +47,9 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
std::map<descriptor::Tensor*, descriptor::Tensor*> in_place_outputs;
std::set<const descriptor::Tensor*> reused_inputs;
if (auto op = std::dynamic_pointer_cast<op::Op>(node))
if (node->is_op())
{
auto op = std::static_pointer_cast<op::Op>(node);
// concat and slice in_place_oi should be treated differently
if (!std::dynamic_pointer_cast<op::Concat>(node) &&
!std::dynamic_pointer_cast<op::Slice>(node))
......
......@@ -28,8 +28,9 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi
{
for (auto& node : function->get_ordered_ops())
{
if (auto op = std::dynamic_pointer_cast<op::Op>(node))
if (node->is_op())
{
auto op = std::static_pointer_cast<op::Op>(node);
NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name();
auto op_annotations = op->get_op_annotations();
if (!op_annotations)
......@@ -55,8 +56,9 @@ bool ngraph::pass::PropagateCacheability::run_on_function(std::shared_ptr<Functi
for (auto arg : node->get_arguments())
{
NGRAPH_DEBUG << "propagate cacheability: arg is " << arg->get_name();
if (auto arg_op = std::dynamic_pointer_cast<op::Op>(arg))
if (arg->is_op())
{
auto arg_op = std::static_pointer_cast<op::Op>(arg);
auto arg_op_annotations = arg_op->get_op_annotations();
NGRAPH_ASSERT(arg_op_annotations);
if (!arg_op_annotations->is_cacheable())
......
......@@ -1135,12 +1135,14 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input(
stack.pop_front();
for (auto input : it->get_inputs())
{
auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node());
if (!c_op || c_op->is_output() || dynamic_pointer_cast<ngraph::op::Slice>(c_op))
auto input_node = input->get_node();
if (!input_node->is_op() || input_node->is_output() ||
dynamic_pointer_cast<ngraph::op::Slice>(input_node))
{
continue;
}
auto c_op = std::static_pointer_cast<ngraph::op::Op>(input_node);
if (auto op_annotations = c_op->get_op_annotations())
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
......@@ -1182,12 +1184,13 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_constant(
stack.pop_front();
for (auto input : it->get_inputs())
{
auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node());
if (!c_op || c_op->is_output())
auto input_node = input->get_node();
if (!input_node->is_op() || input_node->is_output())
{
continue;
}
auto c_op = std::static_pointer_cast<ngraph::op::Op>(input_node);
if (auto op_annotations = c_op->get_op_annotations())
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
......@@ -1229,11 +1232,14 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
do
{
propagate_further = false;
auto arg = std::dynamic_pointer_cast<ngraph::op::Op>(it->get_node());
if (!arg || std::dynamic_pointer_cast<ngraph::op::Slice>(it->get_node()))
auto it_node = it->get_node();
if (!it_node->is_op() || std::dynamic_pointer_cast<ngraph::op::Slice>(it_node))
{
break;
}
auto arg = std::static_pointer_cast<ngraph::op::Op>(it_node);
if (auto op_annotations = arg->get_op_annotations())
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
......@@ -1284,8 +1290,7 @@ void runtime::cpu::CPU_ExternalFunction::process_in_place_concat(
auto offset = output_tensor->get_pool_offset();
for (auto arg : concat->get_arguments())
{
auto input_node = std::dynamic_pointer_cast<ngraph::op::Op>(arg);
auto input_tensor = &input_node->get_output_tensor();
auto input_tensor = &arg->get_output_tensor();
auto old_offset = input_tensor->get_pool_offset();
input_tensor->set_pool_offset(offset);
NGRAPH_DEBUG << "cpu_external_function: change offset, old offset is "
......@@ -1349,8 +1354,7 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_concat(
auto offset = output_tensor->get_pool_offset();
for (auto arg : it->get_arguments())
{
auto input_node = std::dynamic_pointer_cast<ngraph::op::Op>(arg);
auto input_tensor = &input_node->get_output_tensor();
auto input_tensor = &arg->get_output_tensor();
auto old_offset = input_tensor->get_pool_offset();
input_tensor->set_pool_offset(offset);
NGRAPH_DEBUG
......@@ -1383,8 +1387,7 @@ void runtime::cpu::CPU_ExternalFunction::process_in_place_slice(
auto input = &slice->get_inputs().at(0);
auto arg = input->get_output().get_node();
auto index = input->get_output().get_index();
auto input_node = std::dynamic_pointer_cast<ngraph::op::Op>(arg);
auto input_tensor = &input_node->get_output_tensor(index);
auto input_tensor = &arg->get_output_tensor(index);
if (m_tensor_roles[input_tensor->get_name()] == CPUTensorRole::INPUT)
{
NGRAPH_DEBUG << "cpu_external_function: function input pointer passed to "
......
......@@ -132,8 +132,9 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
if (!std::dynamic_pointer_cast<op::Concat>(arg))
{
if (auto op = std::dynamic_pointer_cast<op::Op>(arg))
if (arg->is_op())
{
auto op = std::static_pointer_cast<op::Op>(arg);
auto annotation = op->get_op_annotations();
if (annotation && annotation->get_in_place_oi_pairs().size() > 0)
......@@ -174,8 +175,9 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
{
if ((user != concat))
{
if (auto op = std::dynamic_pointer_cast<op::Op>(user))
if (user->is_op())
{
auto op = std::static_pointer_cast<op::Op>(user);
if (auto op_annotations = op->get_op_annotations())
{
if (op_annotations->get_in_place_oi_pairs().size() > 0)
......
......@@ -765,12 +765,13 @@ void runtime::gpu::GPU_ExternalFunction::propagate_in_place_input(
stack.pop_front();
for (auto input : it->get_inputs())
{
auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node());
if (!c_op || c_op->is_output())
auto input_node = input->get_node();
if (!input_node->is_op() || input_node->is_output())
{
continue;
}
auto c_op = std::static_pointer_cast<ngraph::op::Op>(input_node);
if (auto op_annotations = c_op->get_op_annotations())
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
......@@ -804,11 +805,11 @@ void runtime::gpu::GPU_ExternalFunction::propagate_in_place_output(
do
{
propagate_further = false;
auto arg = std::dynamic_pointer_cast<ngraph::op::Op>(it->get_node());
if (!arg)
if (!it->get_node()->is_op())
{
break;
}
auto arg = std::static_pointer_cast<ngraph::op::Op>(it->get_node());
if (auto op_annotations = arg->get_op_annotations())
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment