Unverified Commit 95ce59ab authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Switch op dynamic casts that crept back in (#3709)

* Switch op dynamic casts that crept back in

* PR for debugging purposes (#3718)

* back out

* transpose

* next try

* Another is_type

* Another is_type

* next is_type

* Revert, is_type remainder

* verify

* v1 and Dyn mixup

* Work around clang++-3.9 bug
parent 9ae3f6be
...@@ -181,7 +181,7 @@ size_t op::v1::Gather::get_axis() const ...@@ -181,7 +181,7 @@ size_t op::v1::Gather::get_axis() const
{ {
int64_t axis = AXIS_NOT_SET_VALUE; int64_t axis = AXIS_NOT_SET_VALUE;
auto axes_input_node = input_value(AXIS).get_node_shared_ptr(); auto axes_input_node = input_value(AXIS).get_node_shared_ptr();
if (auto const_op = dynamic_pointer_cast<op::Constant>(axes_input_node)) if (auto const_op = as_type_ptr<op::Constant>(axes_input_node))
{ {
axis = const_op->get_vector<int64_t>()[0]; axis = const_op->get_vector<int64_t>()[0];
} }
......
...@@ -117,7 +117,7 @@ namespace ngraph ...@@ -117,7 +117,7 @@ namespace ngraph
{ {
public: public:
NGRAPH_API NGRAPH_API
static constexpr NodeTypeInfo type_info{"DynReshape", 1}; static constexpr NodeTypeInfo type_info{"Reshape", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
Reshape() = default; Reshape() = default;
/// \brief Constructs a dynamic reshape operation. This operation does not perform /// \brief Constructs a dynamic reshape operation. This operation does not perform
......
...@@ -137,7 +137,7 @@ void op::v1::Reverse::validate_and_infer_types() ...@@ -137,7 +137,7 @@ void op::v1::Reverse::validate_and_infer_types()
if (rev_axes_node->is_constant()) if (rev_axes_node->is_constant())
{ {
const auto rev_axes_constant = dynamic_pointer_cast<op::Constant>(rev_axes_node); const auto rev_axes_constant = as_type_ptr<op::Constant>(rev_axes_node);
if (m_mode == Mode::INDEX) if (m_mode == Mode::INDEX)
{ {
......
...@@ -225,7 +225,7 @@ size_t op::v1::TopK::read_k_from_constant_node(const shared_ptr<Node>& node, ...@@ -225,7 +225,7 @@ size_t op::v1::TopK::read_k_from_constant_node(const shared_ptr<Node>& node,
k_element_type, k_element_type,
")."); ").");
const auto k_constant = dynamic_pointer_cast<op::Constant>(node); const auto k_constant = as_type_ptr<op::Constant>(node);
size_t k = 0; size_t k = 0;
......
...@@ -4079,8 +4079,7 @@ namespace ngraph ...@@ -4079,8 +4079,7 @@ namespace ngraph
get_goe_input_output(ngraph::descriptor::Output* output) get_goe_input_output(ngraph::descriptor::Output* output)
{ {
auto it = output; auto it = output;
while (auto goe = while (auto goe = as_type_ptr<ngraph::op::GetOutputElement>(it->get_node()))
std::dynamic_pointer_cast<ngraph::op::GetOutputElement>(it->get_node()))
{ {
it = &goe->get_inputs().at(0).get_output(); it = &goe->get_inputs().at(0).get_output();
} }
...@@ -4154,7 +4153,7 @@ namespace ngraph ...@@ -4154,7 +4153,7 @@ namespace ngraph
loop_symbol_table.at(get_goe_input_output(&input.get_output()))); loop_symbol_table.at(get_goe_input_output(&input.get_output())));
} }
if (std::dynamic_pointer_cast<ngraph::op::Relu>(op_node)) if (as_type_ptr<ngraph::op::Relu>(op_node))
{ {
auto casted_zero = std::string("static_cast<") + auto casted_zero = std::string("static_cast<") +
op->get_element_type().c_type_string() + op->get_element_type().c_type_string() +
......
...@@ -80,8 +80,7 @@ namespace ngraph ...@@ -80,8 +80,7 @@ namespace ngraph
{ {
auto qc = static_cast<const OP*>(node); auto qc = static_cast<const OP*>(node);
std::vector<float> scale_val = {1.0f}; std::vector<float> scale_val = {1.0f};
auto scale_const_op = auto scale_const_op = as_type_ptr<ngraph::op::Constant>(qc->get_arguments()[index]);
std::dynamic_pointer_cast<ngraph::op::Constant>(qc->get_arguments()[index]);
if (scale_const_op != nullptr) if (scale_const_op != nullptr)
{ {
scale_val = scale_const_op->template get_vector<float>(); scale_val = scale_const_op->template get_vector<float>();
...@@ -543,8 +542,8 @@ namespace ngraph ...@@ -543,8 +542,8 @@ namespace ngraph
{ {
auto index = get_scale_index<OP>(); auto index = get_scale_index<OP>();
std::vector<T> scale_val = {0}; std::vector<T> scale_val = {0};
auto scale_const_op = std::dynamic_pointer_cast<ngraph::op::Constant>( auto scale_const_op =
node->get_arguments()[index]); as_type_ptr<ngraph::op::Constant>(node->get_arguments()[index]);
if (scale_const_op != nullptr) if (scale_const_op != nullptr)
{ {
scale_val = scale_const_op->template get_vector<T>(); scale_val = scale_const_op->template get_vector<T>();
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/op/experimental/dyn_slice.hpp" #include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp" #include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/transpose.hpp" #include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pass/constant_folding.hpp" #include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/dyn_elimination.hpp" #include "ngraph/pass/dyn_elimination.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
...@@ -78,23 +79,26 @@ runtime::dynamic::DynamicExecutable::DynamicExecutable(shared_ptr<Function> wrap ...@@ -78,23 +79,26 @@ runtime::dynamic::DynamicExecutable::DynamicExecutable(shared_ptr<Function> wrap
set_parameters_and_results(*wrapped_function); set_parameters_and_results(*wrapped_function);
} }
// Due to clang++-3.9 bugs, this needs to be a non-static separate function from
// count_dyn_nodes.
bool is_dynamic_op(const std::shared_ptr<Node>& op)
{
return is_type<op::Transpose>(op) || is_type<op::DynBroadcast>(op) ||
is_type<op::DynReplaceSlice>(op) || is_type<op::DynSlice>(op) ||
is_type<op::v1::Reshape>(op) || is_type<op::DynReshape>(op) || is_type<op::Range>(op);
}
// Helper for a vile hack in DynamicExecutable::call. See body of that function for details. // Helper for a vile hack in DynamicExecutable::call. See body of that function for details.
static size_t count_dyn_nodes(const shared_ptr<ngraph::Function>& f) static size_t count_dyn_nodes(const shared_ptr<ngraph::Function>& f)
{ {
size_t count = 0; size_t count = 0;
for (auto op : f->get_ops()) for (auto op : f->get_ops())
{ {
if (std::dynamic_pointer_cast<op::Transpose>(op) || if (is_dynamic_op(op))
std::dynamic_pointer_cast<op::DynBroadcast>(op) ||
std::dynamic_pointer_cast<op::DynReplaceSlice>(op) ||
std::dynamic_pointer_cast<op::DynSlice>(op) ||
std::dynamic_pointer_cast<op::DynReshape>(op) ||
std::dynamic_pointer_cast<op::Range>(op))
{ {
count++; count++;
} }
} }
return count; return count;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment