Unverified Commit 02a8fa95 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Printing cleanup (#4223)

* Printing cleanup

Add Input/Output
Simplify element type
Simplfy Node printing  implementation, try to de-noise it a bit
Enable printing of Node*

* Adjust printing

* Add doc note

* Update src/ngraph/node.cpp
Co-Authored-By: 's avatarRobert Kimball <robert.kimball@intel.com>

* Update src/ngraph/node.hpp
Co-Authored-By: 's avatarRobert Kimball <robert.kimball@intel.com>

* Cleanup

* typo
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 1974a90d
...@@ -140,4 +140,11 @@ in the unit test for this feature. ...@@ -140,4 +140,11 @@ in the unit test for this feature.
.. _pass config: https://github.com/NervanaSystems/ngraph/blob/a4a3031bb40f19ec28704f76de39762e1f27e031/src/ngraph/pass/pass_config.cpp#L54 .. _pass config: https://github.com/NervanaSystems/ngraph/blob/a4a3031bb40f19ec28704f76de39762e1f27e031/src/ngraph/pass/pass_config.cpp#L54
.. _OpenMPI Runtime Library Documentation: https://www.openmprtl.org/documentation .. _OpenMPI Runtime Library Documentation: https://www.openmprtl.org/documentation
.. _precompiled MLIR: https://github.com/IntelAI/mlir .. _precompiled MLIR: https://github.com/IntelAI/mlir
\ No newline at end of file
Looking at graph objects
------------------------
A number of nGraph objects can print themselves on streams. For example,``cerr << a + b`` produces
``v0::Add Add_2(Parameter_0[0]:f32{2,3}, Parameter_1[0]:f32{2,3}):(f32{2,3})`` indicating the
specific version of the op, its name, arguments, and outputs.
...@@ -604,50 +604,35 @@ const op::AutoBroadcastSpec& Node::get_autob() const ...@@ -604,50 +604,35 @@ const op::AutoBroadcastSpec& Node::get_autob() const
namespace ngraph namespace ngraph
{ {
ostream& operator<<(ostream& out, const Node& node) ostream& operator<<(ostream& out, const Node& node) { return node.write_description(out, 1); }
{ ostream& operator<<(ostream& out, const Node* node) { return node->write_description(out, 1); }
return out << NodeDescription(node, false);
}
}
std::ostream& Node::write_short_description(std::ostream& out) const
{
return out << get_name();
} }
static std::string pretty_element_type(const element::Type& et) std::ostream& Node::write_description(std::ostream& out, uint32_t depth) const
{ {
if (et.is_dynamic()) if (depth == 0)
{ {
return "?"; out << get_name();
} }
else else
{ {
return et.c_type_string(); out << "v" << get_type_info().version << "::" << get_type_info().name << " " << get_name()
} << "(";
} string sep = "";
for (auto arg : input_values())
std::ostream& Node::write_long_description(std::ostream& out) const {
{ out << sep << arg;
out << description() << '[' << get_name() << "]("; sep = ", ";
string sep = ""; }
for (auto arg : get_arguments()) out << ") -> (";
{ sep = "";
out << sep << NodeDescription(*arg, true) << ": " for (size_t i = 0; i < get_output_size(); i++)
<< pretty_element_type(arg->get_output_element_type(0)) {
<< arg->get_output_partial_shape(0); out << sep << get_output_element_type(i) << get_output_partial_shape(i);
sep = ", "; sep = ", ";
} }
out << ") -> ("; out << ")";
sep = "";
for (size_t i = 0; i < get_output_size(); i++)
{
out << sep << pretty_element_type(get_output_element_type(i))
<< get_output_partial_shape(i);
sep = ", ";
} }
out << ")";
return out; return out;
} }
...@@ -977,6 +962,37 @@ bool Node::is_dynamic() const ...@@ -977,6 +962,37 @@ bool Node::is_dynamic() const
return false; return false;
} }
namespace ngraph
{
std::ostream& operator<<(std::ostream& out, const Output<Node>& output)
{
return output.get_node()->write_description(out, 0) << "[" << output.get_index()
<< "]:" << output.get_element_type()
<< output.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Output<const Node>& output)
{
return output.get_node()->write_description(out, 0) << "[" << output.get_index()
<< "]:" << output.get_element_type()
<< output.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Input<Node>& input)
{
return input.get_node()->write_description(out, 0) << ".input(" << input.get_index()
<< "):" << input.get_element_type()
<< input.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Input<const Node>& input)
{
return input.get_node()->write_description(out, 0) << ".input(" << input.get_index()
<< "):" << input.get_element_type()
<< input.get_partial_shape();
}
}
Input<Node> Node::input(size_t input_index) Input<Node> Node::input(size_t input_index)
{ {
if (input_index >= m_inputs.size()) if (input_index >= m_inputs.size())
......
...@@ -270,9 +270,11 @@ namespace ngraph ...@@ -270,9 +270,11 @@ namespace ngraph
virtual bool is_dynamic() const; virtual bool is_dynamic() const;
virtual bool has_state() const { return false; } virtual bool has_state() const { return false; }
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend NGRAPH_API std::ostream& operator<<(std::ostream&, const Node&); /// \brief Writes a description of a node to a stream
virtual std::ostream& write_short_description(std::ostream&) const; /// \param os The stream; should be returned
virtual std::ostream& write_long_description(std::ostream&) const; /// \param depth How many levels of inputs to describe
/// \returns The stream os
virtual std::ostream& write_description(std::ostream& os, uint32_t depth = 0) const;
std::deque<descriptor::Input>& get_inputs() NGRAPH_DEPRECATED("use inputs() instead") std::deque<descriptor::Input>& get_inputs() NGRAPH_DEPRECATED("use inputs() instead")
{ {
...@@ -539,6 +541,9 @@ namespace ngraph ...@@ -539,6 +541,9 @@ namespace ngraph
using NodeTypeInfo = Node::type_info_t; using NodeTypeInfo = Node::type_info_t;
NGRAPH_API std::ostream& operator<<(std::ostream&, const Node&);
NGRAPH_API std::ostream& operator<<(std::ostream&, const Node*);
template <typename NodeType> template <typename NodeType>
class Input class Input
{ {
...@@ -914,6 +919,11 @@ namespace ngraph ...@@ -914,6 +919,11 @@ namespace ngraph
size_t m_index{0}; size_t m_index{0};
}; };
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<Node>& output);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<const Node>& output);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<Node>& input);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<const Node>& input);
inline Output<Node> Input<Node>::get_source_output() const inline Output<Node> Input<Node>::get_source_output() const
{ {
auto& output_descriptor = m_node->m_inputs.at(m_index).get_output(); auto& output_descriptor = m_node->m_inputs.at(m_index).get_output();
...@@ -972,25 +982,6 @@ namespace ngraph ...@@ -972,25 +982,6 @@ namespace ngraph
{ {
} }
}; };
class NodeDescription
{
public:
NodeDescription(const Node& node, bool is_short)
: m_node(node)
, m_is_short(is_short)
{
}
friend std::ostream& operator<<(std::ostream& out, const NodeDescription node_description)
{
return node_description.m_is_short
? node_description.m_node.write_short_description(out)
: node_description.m_node.write_long_description(out);
}
const Node& m_node;
bool m_is_short;
};
} }
#define NODE_VALIDATION_CHECK(node, ...) \ #define NODE_VALIDATION_CHECK(node, ...) \
NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), __VA_ARGS__) NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), __VA_ARGS__)
......
...@@ -239,9 +239,7 @@ namespace ngraph ...@@ -239,9 +239,7 @@ namespace ngraph
std::ostream& element::operator<<(std::ostream& out, const element::Type& obj) std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
{ {
out << "element::Type{" << obj.bitwidth() << ", " << obj.is_real() << ", " << obj.is_signed() return out << obj.get_type_name();
<< ", " << obj.is_quantized() << ", \"" << obj.c_type_string() << "\"}";
return out;
} }
bool element::Type::compatible(const element::Type& t) const bool element::Type::compatible(const element::Type& t) const
......
...@@ -192,8 +192,7 @@ TEST(type_prop, dequantize_i8_from_u8_fails) ...@@ -192,8 +192,7 @@ TEST(type_prop, dequantize_i8_from_u8_fails)
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Output element type (element::Type{8, 0, 1, 1, \"int8_t\"}) must be " "Output element type (i8) must be a floating point type");
"a floating point type");
} }
catch (...) catch (...)
{ {
...@@ -224,10 +223,8 @@ TEST(type_prop, dequantize_f32_from_f32_fails) ...@@ -224,10 +223,8 @@ TEST(type_prop, dequantize_f32_from_f32_fails)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), "Zero point / input element type (f32) must be a quantized type");
"Zero point / input element type (element::Type{32, 1, 1, 0, \"float\"}) "
"must be a quantized type");
} }
catch (...) catch (...)
{ {
...@@ -258,10 +255,8 @@ TEST(type_prop, dequantize_batch_zero_point_type_mismatch_fails) ...@@ -258,10 +255,8 @@ TEST(type_prop, dequantize_batch_zero_point_type_mismatch_fails)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), "Zero point element type (u8) must match input element type (i8)");
"Zero point element type (element::Type{8, 0, 0, 1, \"uint8_t\"}) must "
"match input element type (element::Type{8, 0, 1, 1, \"int8_t\"})");
} }
catch (...) catch (...)
{ {
...@@ -293,10 +288,7 @@ TEST(type_prop, dequantize_scale_type_mismatch_fails) ...@@ -293,10 +288,7 @@ TEST(type_prop, dequantize_scale_type_mismatch_fails)
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Scale element type (element::Type{64, 1, 1, 0, \"double\"}) must " "Scale element type (f64) must match output element type (f32)");
"match output element type (element::Type{32, 1, 1, 0, \"float\"})"
);
} }
catch (...) catch (...)
{ {
......
...@@ -241,8 +241,7 @@ TEST(type_prop, quantize_i8_to_u8_fails) ...@@ -241,8 +241,7 @@ TEST(type_prop, quantize_i8_to_u8_fails)
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Scale / input element type (element::Type{8, 0, 1, 1, \"int8_t\"}) " "Scale / input element type (i8) must be a floating point number");
"must be a floating point number");
} }
catch (...) catch (...)
{ {
...@@ -275,9 +274,7 @@ TEST(type_prop, quantize_f32_to_f32_fails) ...@@ -275,9 +274,7 @@ TEST(type_prop, quantize_f32_to_f32_fails)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), "Output element type (f32) must be a quantized type");
error.what(),
"Output element type (element::Type{32, 1, 1, 0, \"float\"}) must be a quantized type");
} }
catch (...) catch (...)
{ {
...@@ -311,8 +308,7 @@ TEST(type_prop, quantize_batch_scale_type_mismatch_fails) ...@@ -311,8 +308,7 @@ TEST(type_prop, quantize_batch_scale_type_mismatch_fails)
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Scale element type (element::Type{64, 1, 1, 0, \"double\"}) must " "Scale element type (f64) must match input element type (f32)");
"match input element type (element::Type{32, 1, 1, 0, \"float\"})");
} }
catch (...) catch (...)
{ {
...@@ -345,10 +341,8 @@ TEST(type_prop, quantize_zero_point_type_mismatch_fails) ...@@ -345,10 +341,8 @@ TEST(type_prop, quantize_zero_point_type_mismatch_fails)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), "Zero point element type (u8) must match output element type (i8)");
"Zero point element type (element::Type{8, 0, 0, 1, \"uint8_t\"}) must "
"match output element type (element::Type{8, 0, 1, 1, \"int8_t\"})");
} }
catch (...) catch (...)
{ {
......
...@@ -161,9 +161,7 @@ TEST(type_prop, quantized_conv_non_quantized_input_fails) ...@@ -161,9 +161,7 @@ TEST(type_prop, quantized_conv_non_quantized_input_fails)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(), "Input element type (f32) must be a quantized type");
"Input element type (element::Type{32, 1, 1, 0, \"float\"}) "
"must be a quantized type");
} }
catch (...) catch (...)
{ {
...@@ -218,9 +216,7 @@ TEST(type_prop, quantized_conv_non_quantized_filter_fails) ...@@ -218,9 +216,7 @@ TEST(type_prop, quantized_conv_non_quantized_filter_fails)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(), "Filter element type (f32) must be a quantized type");
"Filter element type (element::Type{32, 1, 1, 0, \"float\"}) "
"must be a quantized type");
} }
catch (...) catch (...)
{ {
...@@ -387,9 +383,7 @@ TEST(type_prop, quantized_conv_input_zero_point_type_mismatch_fails) ...@@ -387,9 +383,7 @@ TEST(type_prop, quantized_conv_input_zero_point_type_mismatch_fails)
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(), "Input Zero point element type (i8) must match input element type (u8)");
"Input Zero point element type (element::Type{8, 0, 1, 1, \"int8_t\"}) must "
"match input element type (element::Type{8, 0, 0, 1, \"uint8_t\"})");
} }
catch (...) catch (...)
{ {
...@@ -447,8 +441,7 @@ TEST(type_prop, quantized_conv_filter_zero_point_type_mismatch_fails) ...@@ -447,8 +441,7 @@ TEST(type_prop, quantized_conv_filter_zero_point_type_mismatch_fails)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
"Filter Zero point element type (element::Type{8, 0, 0, 1, \"uint8_t\"}) must " "Filter Zero point element type (u8) must match filter element type (i8)");
"match filter element type (element::Type{8, 0, 1, 1, \"int8_t\"})");
} }
catch (...) catch (...)
{ {
......
...@@ -141,9 +141,7 @@ TEST(type_prop, quantized_dot_non_quantized_input0_fails) ...@@ -141,9 +141,7 @@ TEST(type_prop, quantized_dot_non_quantized_input0_fails)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(), "Input0 element type (f32) must be a quantized type");
"Input0 element type (element::Type{32, 1, 1, 0, \"float\"}) "
"must be a quantized type");
} }
catch (...) catch (...)
{ {
...@@ -192,9 +190,7 @@ TEST(type_prop, quantized_dot_non_quantized_input1_fails) ...@@ -192,9 +190,7 @@ TEST(type_prop, quantized_dot_non_quantized_input1_fails)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(), "Input1 element type (f32) must be a quantized type");
"Input1 element type (element::Type{32, 1, 1, 0, \"float\"}) "
"must be a quantized type");
} }
catch (...) catch (...)
{ {
...@@ -343,8 +339,7 @@ TEST(type_prop, quantized_dot_input0_zero_point_type_mismatch_fails) ...@@ -343,8 +339,7 @@ TEST(type_prop, quantized_dot_input0_zero_point_type_mismatch_fails)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
"Input0 Zero point element type (element::Type{8, 0, 1, 1, \"int8_t\"}) must " "Input0 Zero point element type (i8) must match input0 element type (u8)");
"match input0 element type (element::Type{8, 0, 0, 1, \"uint8_t\"})");
} }
catch (...) catch (...)
{ {
...@@ -395,8 +390,7 @@ TEST(type_prop, quantized_dot_input1_zero_point_type_mismatch_fails) ...@@ -395,8 +390,7 @@ TEST(type_prop, quantized_dot_input1_zero_point_type_mismatch_fails)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
"Input1 Zero point element type (element::Type{8, 0, 0, 1, \"uint8_t\"}) must " "Input1 Zero point element type (u8) must match input1 element type (i8)");
"match input1 element type (element::Type{8, 0, 1, 1, \"int8_t\"})");
} }
catch (...) catch (...)
{ {
......
...@@ -70,9 +70,7 @@ TEST(type_prop, topk_invalid_index_type) ...@@ -70,9 +70,7 @@ TEST(type_prop, topk_invalid_index_type)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must be i64 or i32 (got f32)");
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
} }
catch (...) catch (...)
{ {
...@@ -164,9 +162,7 @@ TEST(type_prop, topk_rank_dynamic_result_et_invalid) ...@@ -164,9 +162,7 @@ TEST(type_prop, topk_rank_dynamic_result_et_invalid)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must be i64 or i32 (got f32)");
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
} }
catch (...) catch (...)
{ {
...@@ -234,9 +230,7 @@ TEST(type_prop, topk_rank_static_dynamic_axis_oob) ...@@ -234,9 +230,7 @@ TEST(type_prop, topk_rank_static_dynamic_axis_oob)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must be i64 or i32 (got f32)");
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
} }
catch (...) catch (...)
{ {
...@@ -262,9 +256,7 @@ TEST(type_prop, topk_rank_static_dynamic_k_unknown_axis_oob) ...@@ -262,9 +256,7 @@ TEST(type_prop, topk_rank_static_dynamic_k_unknown_axis_oob)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must be i64 or i32 (got f32)");
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
} }
catch (...) catch (...)
{ {
...@@ -290,9 +282,7 @@ TEST(type_prop, topk_rank_static_dynamic_k_known_too_big) ...@@ -290,9 +282,7 @@ TEST(type_prop, topk_rank_static_dynamic_k_known_too_big)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must be i64 or i32 (got f32)");
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
} }
catch (...) catch (...)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment