Unverified Commit cc754735 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Fix incorrect uses of `description()` (#3946)

* Fix incorrect uses of `description()`

* type-o/namespace
parent 77a99b30
......@@ -91,15 +91,8 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
// We apply the same general-purposes passes as the CPU backend.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>([](const Node& node) -> bool {
if (node.description() == ngraph::op::GroupConvolution().description())
{
return true;
}
else if (node.description() == ngraph::op::LayerNorm().description())
{
return true;
}
return false;
return ngraph::is_type<op::GroupConvolution>(&node) ||
ngraph::is_type<op::LayerNorm>(&node);
});
pass_manager.register_pass<ngraph::pass::Opset0Downgrade>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
......
......@@ -34,6 +34,7 @@
#include "ngraph/except.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
......@@ -161,29 +162,6 @@ void print_results(vector<PerfShape> perf_data, bool timing_detail)
}
}
element::Type get_op_element_type(const Node& op)
{
element::Type type;
if (op.description() == "Convert")
{
type = op.input(0).get_element_type();
}
else if (op.description() == "Equal" || op.description() == "Greater" ||
op.description() == "GreaterEq" || op.description() == "Less" ||
op.description() == "LessEq" || op.description() == "NotEqual")
{
// Get the type of the second input, not the first
// All BinaryElementwiseComparision ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second
type = op.input(1).get_element_type();
}
else
{
type = op.output(0).get_element_type();
}
return type;
}
int main(int argc, char** argv)
{
string model_arg;
......@@ -373,6 +351,10 @@ OPTIONS
set<string> type_list;
for (shared_ptr<Node> node : f->get_ordered_ops())
{
for (auto value : node->outputs())
{
type_list.insert(value.get_element_type().c_type_string());
}
for (descriptor::Tensor* tensor : node->liveness_new_list)
{
total_temporary_bytes += tensor->size();
......@@ -381,11 +363,8 @@ OPTIONS
string op_name = node->description();
string shape_name = "{" + join(node->output(0).get_shape()) + "}";
op_list[op_name + shape_name]++;
auto et = get_op_element_type(*node);
string type_string = et.c_type_string();
type_list.insert(type_string);
if (op_name == "Constant")
if (node->is_constant())
{
total_constant_count++;
const Shape& shape = node->output(0).get_shape();
......@@ -400,14 +379,14 @@ OPTIONS
(const_size * shape_size(node->output(0).get_shape()));
}
}
else if (op_name == "Parameter")
else if (node->is_parameter())
{
total_parameter_count++;
const Shape& shape = node->output(0).get_shape();
size_t size = node->output(0).get_element_type().size() * shape_size(shape);
total_parameter_bytes += size;
}
else if (op_name == "Result")
else if (is_type<op::Result>(node))
{
total_result_count++;
const Shape& shape = node->input(0).get_shape();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment