Unverified Commit 95ce59ab authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Switch op dynamic casts that crept back in (#3709)

* Switch op dynamic casts that crept back in

* PR for debugging purposes (#3718)

* back out

* transpose

* next try

* Another is_type

* Another is_type

* next is_type

* Revert, is_type remainder

* verify

* v1 and Dyn mixup

* Work around clang++-3.9 bug
parent 9ae3f6be
......@@ -181,7 +181,7 @@ size_t op::v1::Gather::get_axis() const
{
int64_t axis = AXIS_NOT_SET_VALUE;
auto axes_input_node = input_value(AXIS).get_node_shared_ptr();
if (auto const_op = dynamic_pointer_cast<op::Constant>(axes_input_node))
if (auto const_op = as_type_ptr<op::Constant>(axes_input_node))
{
axis = const_op->get_vector<int64_t>()[0];
}
......
......@@ -117,7 +117,7 @@ namespace ngraph
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"DynReshape", 1};
static constexpr NodeTypeInfo type_info{"Reshape", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Reshape() = default;
/// \brief Constructs a dynamic reshape operation. This operation does not perform
......
......@@ -137,7 +137,7 @@ void op::v1::Reverse::validate_and_infer_types()
if (rev_axes_node->is_constant())
{
const auto rev_axes_constant = dynamic_pointer_cast<op::Constant>(rev_axes_node);
const auto rev_axes_constant = as_type_ptr<op::Constant>(rev_axes_node);
if (m_mode == Mode::INDEX)
{
......
......@@ -225,7 +225,7 @@ size_t op::v1::TopK::read_k_from_constant_node(const shared_ptr<Node>& node,
k_element_type,
").");
const auto k_constant = dynamic_pointer_cast<op::Constant>(node);
const auto k_constant = as_type_ptr<op::Constant>(node);
size_t k = 0;
......
......@@ -4079,8 +4079,7 @@ namespace ngraph
get_goe_input_output(ngraph::descriptor::Output* output)
{
auto it = output;
while (auto goe =
std::dynamic_pointer_cast<ngraph::op::GetOutputElement>(it->get_node()))
while (auto goe = as_type_ptr<ngraph::op::GetOutputElement>(it->get_node()))
{
it = &goe->get_inputs().at(0).get_output();
}
......@@ -4154,7 +4153,7 @@ namespace ngraph
loop_symbol_table.at(get_goe_input_output(&input.get_output())));
}
if (std::dynamic_pointer_cast<ngraph::op::Relu>(op_node))
if (as_type_ptr<ngraph::op::Relu>(op_node))
{
auto casted_zero = std::string("static_cast<") +
op->get_element_type().c_type_string() +
......
......@@ -80,8 +80,7 @@ namespace ngraph
{
auto qc = static_cast<const OP*>(node);
std::vector<float> scale_val = {1.0f};
auto scale_const_op =
std::dynamic_pointer_cast<ngraph::op::Constant>(qc->get_arguments()[index]);
auto scale_const_op = as_type_ptr<ngraph::op::Constant>(qc->get_arguments()[index]);
if (scale_const_op != nullptr)
{
scale_val = scale_const_op->template get_vector<float>();
......@@ -543,8 +542,8 @@ namespace ngraph
{
auto index = get_scale_index<OP>();
std::vector<T> scale_val = {0};
auto scale_const_op = std::dynamic_pointer_cast<ngraph::op::Constant>(
node->get_arguments()[index]);
auto scale_const_op =
as_type_ptr<ngraph::op::Constant>(node->get_arguments()[index]);
if (scale_const_op != nullptr)
{
scale_val = scale_const_op->template get_vector<T>();
......
......@@ -22,6 +22,7 @@
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/dyn_elimination.hpp"
#include "ngraph/pass/manager.hpp"
......@@ -78,23 +79,26 @@ runtime::dynamic::DynamicExecutable::DynamicExecutable(shared_ptr<Function> wrap
set_parameters_and_results(*wrapped_function);
}
// Due to clang++-3.9 bugs, this needs to be a non-static separate function from
// count_dyn_nodes.
bool is_dynamic_op(const std::shared_ptr<Node>& op)
{
return is_type<op::Transpose>(op) || is_type<op::DynBroadcast>(op) ||
is_type<op::DynReplaceSlice>(op) || is_type<op::DynSlice>(op) ||
is_type<op::v1::Reshape>(op) || is_type<op::DynReshape>(op) || is_type<op::Range>(op);
}
// Helper for a vile hack in DynamicExecutable::call. See body of that function for details.
static size_t count_dyn_nodes(const shared_ptr<ngraph::Function>& f)
{
size_t count = 0;
for (auto op : f->get_ops())
{
if (std::dynamic_pointer_cast<op::Transpose>(op) ||
std::dynamic_pointer_cast<op::DynBroadcast>(op) ||
std::dynamic_pointer_cast<op::DynReplaceSlice>(op) ||
std::dynamic_pointer_cast<op::DynSlice>(op) ||
std::dynamic_pointer_cast<op::DynReshape>(op) ||
std::dynamic_pointer_cast<op::Range>(op))
if (is_dynamic_op(op))
{
count++;
}
}
return count;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment