Unverified Commit 02a8fa95 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Printing cleanup (#4223)

* Printing cleanup

Add Input/Output
Simplify element type
Simplfy Node printing  implementation, try to de-noise it a bit
Enable printing of Node*

* Adjust printing

* Add doc note

* Update src/ngraph/node.cpp
Co-Authored-By: 's avatarRobert Kimball <robert.kimball@intel.com>

* Update src/ngraph/node.hpp
Co-Authored-By: 's avatarRobert Kimball <robert.kimball@intel.com>

* Cleanup

* typo
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 1974a90d
......@@ -140,4 +140,11 @@ in the unit test for this feature.
.. _pass config: https://github.com/NervanaSystems/ngraph/blob/a4a3031bb40f19ec28704f76de39762e1f27e031/src/ngraph/pass/pass_config.cpp#L54
.. _OpenMPI Runtime Library Documentation: https://www.openmprtl.org/documentation
.. _precompiled MLIR: https://github.com/IntelAI/mlir
\ No newline at end of file
.. _precompiled MLIR: https://github.com/IntelAI/mlir
Looking at graph objects
------------------------
A number of nGraph objects can print themselves on streams. For example,``cerr << a + b`` produces
``v0::Add Add_2(Parameter_0[0]:f32{2,3}, Parameter_1[0]:f32{2,3}):(f32{2,3})`` indicating the
specific version of the op, its name, arguments, and outputs.
......@@ -604,50 +604,35 @@ const op::AutoBroadcastSpec& Node::get_autob() const
namespace ngraph
{
ostream& operator<<(ostream& out, const Node& node)
{
return out << NodeDescription(node, false);
}
}
std::ostream& Node::write_short_description(std::ostream& out) const
{
return out << get_name();
ostream& operator<<(ostream& out, const Node& node) { return node.write_description(out, 1); }
ostream& operator<<(ostream& out, const Node* node) { return node->write_description(out, 1); }
}
static std::string pretty_element_type(const element::Type& et)
std::ostream& Node::write_description(std::ostream& out, uint32_t depth) const
{
if (et.is_dynamic())
if (depth == 0)
{
return "?";
out << get_name();
}
else
{
return et.c_type_string();
}
}
std::ostream& Node::write_long_description(std::ostream& out) const
{
out << description() << '[' << get_name() << "](";
string sep = "";
for (auto arg : get_arguments())
{
out << sep << NodeDescription(*arg, true) << ": "
<< pretty_element_type(arg->get_output_element_type(0))
<< arg->get_output_partial_shape(0);
sep = ", ";
}
out << ") -> (";
sep = "";
for (size_t i = 0; i < get_output_size(); i++)
{
out << sep << pretty_element_type(get_output_element_type(i))
<< get_output_partial_shape(i);
sep = ", ";
out << "v" << get_type_info().version << "::" << get_type_info().name << " " << get_name()
<< "(";
string sep = "";
for (auto arg : input_values())
{
out << sep << arg;
sep = ", ";
}
out << ") -> (";
sep = "";
for (size_t i = 0; i < get_output_size(); i++)
{
out << sep << get_output_element_type(i) << get_output_partial_shape(i);
sep = ", ";
}
out << ")";
}
out << ")";
return out;
}
......@@ -977,6 +962,37 @@ bool Node::is_dynamic() const
return false;
}
namespace ngraph
{
std::ostream& operator<<(std::ostream& out, const Output<Node>& output)
{
return output.get_node()->write_description(out, 0) << "[" << output.get_index()
<< "]:" << output.get_element_type()
<< output.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Output<const Node>& output)
{
return output.get_node()->write_description(out, 0) << "[" << output.get_index()
<< "]:" << output.get_element_type()
<< output.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Input<Node>& input)
{
return input.get_node()->write_description(out, 0) << ".input(" << input.get_index()
<< "):" << input.get_element_type()
<< input.get_partial_shape();
}
std::ostream& operator<<(std::ostream& out, const Input<const Node>& input)
{
return input.get_node()->write_description(out, 0) << ".input(" << input.get_index()
<< "):" << input.get_element_type()
<< input.get_partial_shape();
}
}
Input<Node> Node::input(size_t input_index)
{
if (input_index >= m_inputs.size())
......
......@@ -270,9 +270,11 @@ namespace ngraph
virtual bool is_dynamic() const;
virtual bool has_state() const { return false; }
size_t get_instance_id() const { return m_instance_id; }
friend NGRAPH_API std::ostream& operator<<(std::ostream&, const Node&);
virtual std::ostream& write_short_description(std::ostream&) const;
virtual std::ostream& write_long_description(std::ostream&) const;
/// \brief Writes a description of a node to a stream
/// \param os The stream; should be returned
/// \param depth How many levels of inputs to describe
/// \returns The stream os
virtual std::ostream& write_description(std::ostream& os, uint32_t depth = 0) const;
std::deque<descriptor::Input>& get_inputs() NGRAPH_DEPRECATED("use inputs() instead")
{
......@@ -539,6 +541,9 @@ namespace ngraph
using NodeTypeInfo = Node::type_info_t;
NGRAPH_API std::ostream& operator<<(std::ostream&, const Node&);
NGRAPH_API std::ostream& operator<<(std::ostream&, const Node*);
template <typename NodeType>
class Input
{
......@@ -914,6 +919,11 @@ namespace ngraph
size_t m_index{0};
};
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<Node>& output);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<const Node>& output);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<Node>& input);
NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<const Node>& input);
inline Output<Node> Input<Node>::get_source_output() const
{
auto& output_descriptor = m_node->m_inputs.at(m_index).get_output();
......@@ -972,25 +982,6 @@ namespace ngraph
{
}
};
class NodeDescription
{
public:
NodeDescription(const Node& node, bool is_short)
: m_node(node)
, m_is_short(is_short)
{
}
friend std::ostream& operator<<(std::ostream& out, const NodeDescription node_description)
{
return node_description.m_is_short
? node_description.m_node.write_short_description(out)
: node_description.m_node.write_long_description(out);
}
const Node& m_node;
bool m_is_short;
};
}
#define NODE_VALIDATION_CHECK(node, ...) \
NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), __VA_ARGS__)
......
......@@ -239,9 +239,7 @@ namespace ngraph
std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
{
out << "element::Type{" << obj.bitwidth() << ", " << obj.is_real() << ", " << obj.is_signed()
<< ", " << obj.is_quantized() << ", \"" << obj.c_type_string() << "\"}";
return out;
return out << obj.get_type_name();
}
bool element::Type::compatible(const element::Type& t) const
......
......@@ -192,8 +192,7 @@ TEST(type_prop, dequantize_i8_from_u8_fails)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Output element type (element::Type{8, 0, 1, 1, \"int8_t\"}) must be "
"a floating point type");
"Output element type (i8) must be a floating point type");
}
catch (...)
{
......@@ -224,10 +223,8 @@ TEST(type_prop, dequantize_f32_from_f32_fails)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Zero point / input element type (element::Type{32, 1, 1, 0, \"float\"}) "
"must be a quantized type");
EXPECT_HAS_SUBSTRING(error.what(),
"Zero point / input element type (f32) must be a quantized type");
}
catch (...)
{
......@@ -258,10 +255,8 @@ TEST(type_prop, dequantize_batch_zero_point_type_mismatch_fails)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Zero point element type (element::Type{8, 0, 0, 1, \"uint8_t\"}) must "
"match input element type (element::Type{8, 0, 1, 1, \"int8_t\"})");
EXPECT_HAS_SUBSTRING(error.what(),
"Zero point element type (u8) must match input element type (i8)");
}
catch (...)
{
......@@ -293,10 +288,7 @@ TEST(type_prop, dequantize_scale_type_mismatch_fails)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Scale element type (element::Type{64, 1, 1, 0, \"double\"}) must "
"match output element type (element::Type{32, 1, 1, 0, \"float\"})"
);
"Scale element type (f64) must match output element type (f32)");
}
catch (...)
{
......
......@@ -241,8 +241,7 @@ TEST(type_prop, quantize_i8_to_u8_fails)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Scale / input element type (element::Type{8, 0, 1, 1, \"int8_t\"}) "
"must be a floating point number");
"Scale / input element type (i8) must be a floating point number");
}
catch (...)
{
......@@ -275,9 +274,7 @@ TEST(type_prop, quantize_f32_to_f32_fails)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Output element type (element::Type{32, 1, 1, 0, \"float\"}) must be a quantized type");
EXPECT_HAS_SUBSTRING(error.what(), "Output element type (f32) must be a quantized type");
}
catch (...)
{
......@@ -311,8 +308,7 @@ TEST(type_prop, quantize_batch_scale_type_mismatch_fails)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Scale element type (element::Type{64, 1, 1, 0, \"double\"}) must "
"match input element type (element::Type{32, 1, 1, 0, \"float\"})");
"Scale element type (f64) must match input element type (f32)");
}
catch (...)
{
......@@ -345,10 +341,8 @@ TEST(type_prop, quantize_zero_point_type_mismatch_fails)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Zero point element type (element::Type{8, 0, 0, 1, \"uint8_t\"}) must "
"match output element type (element::Type{8, 0, 1, 1, \"int8_t\"})");
EXPECT_HAS_SUBSTRING(error.what(),
"Zero point element type (u8) must match output element type (i8)");
}
catch (...)
{
......
......@@ -161,9 +161,7 @@ TEST(type_prop, quantized_conv_non_quantized_input_fails)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Input element type (element::Type{32, 1, 1, 0, \"float\"}) "
"must be a quantized type");
EXPECT_HAS_SUBSTRING(error.what(), "Input element type (f32) must be a quantized type");
}
catch (...)
{
......@@ -218,9 +216,7 @@ TEST(type_prop, quantized_conv_non_quantized_filter_fails)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Filter element type (element::Type{32, 1, 1, 0, \"float\"}) "
"must be a quantized type");
EXPECT_HAS_SUBSTRING(error.what(), "Filter element type (f32) must be a quantized type");
}
catch (...)
{
......@@ -387,9 +383,7 @@ TEST(type_prop, quantized_conv_input_zero_point_type_mismatch_fails)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Input Zero point element type (element::Type{8, 0, 1, 1, \"int8_t\"}) must "
"match input element type (element::Type{8, 0, 0, 1, \"uint8_t\"})");
error.what(), "Input Zero point element type (i8) must match input element type (u8)");
}
catch (...)
{
......@@ -447,8 +441,7 @@ TEST(type_prop, quantized_conv_filter_zero_point_type_mismatch_fails)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Filter Zero point element type (element::Type{8, 0, 0, 1, \"uint8_t\"}) must "
"match filter element type (element::Type{8, 0, 1, 1, \"int8_t\"})");
"Filter Zero point element type (u8) must match filter element type (i8)");
}
catch (...)
{
......
......@@ -141,9 +141,7 @@ TEST(type_prop, quantized_dot_non_quantized_input0_fails)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Input0 element type (element::Type{32, 1, 1, 0, \"float\"}) "
"must be a quantized type");
EXPECT_HAS_SUBSTRING(error.what(), "Input0 element type (f32) must be a quantized type");
}
catch (...)
{
......@@ -192,9 +190,7 @@ TEST(type_prop, quantized_dot_non_quantized_input1_fails)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Input1 element type (element::Type{32, 1, 1, 0, \"float\"}) "
"must be a quantized type");
EXPECT_HAS_SUBSTRING(error.what(), "Input1 element type (f32) must be a quantized type");
}
catch (...)
{
......@@ -343,8 +339,7 @@ TEST(type_prop, quantized_dot_input0_zero_point_type_mismatch_fails)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Input0 Zero point element type (element::Type{8, 0, 1, 1, \"int8_t\"}) must "
"match input0 element type (element::Type{8, 0, 0, 1, \"uint8_t\"})");
"Input0 Zero point element type (i8) must match input0 element type (u8)");
}
catch (...)
{
......@@ -395,8 +390,7 @@ TEST(type_prop, quantized_dot_input1_zero_point_type_mismatch_fails)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Input1 Zero point element type (element::Type{8, 0, 0, 1, \"uint8_t\"}) must "
"match input1 element type (element::Type{8, 0, 1, 1, \"int8_t\"})");
"Input1 Zero point element type (u8) must match input1 element type (i8)");
}
catch (...)
{
......
......@@ -70,9 +70,7 @@ TEST(type_prop, topk_invalid_index_type)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must be i64 or i32 (got f32)");
}
catch (...)
{
......@@ -164,9 +162,7 @@ TEST(type_prop, topk_rank_dynamic_result_et_invalid)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must be i64 or i32 (got f32)");
}
catch (...)
{
......@@ -234,9 +230,7 @@ TEST(type_prop, topk_rank_static_dynamic_axis_oob)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must be i64 or i32 (got f32)");
}
catch (...)
{
......@@ -262,9 +256,7 @@ TEST(type_prop, topk_rank_static_dynamic_k_unknown_axis_oob)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must be i64 or i32 (got f32)");
}
catch (...)
{
......@@ -290,9 +282,7 @@ TEST(type_prop, topk_rank_static_dynamic_k_known_too_big)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
"Argument element type must be i64 or i32 (got element::Type{32, 1, 1, 0, \"float\"})");
EXPECT_HAS_SUBSTRING(error.what(), "Argument element type must be i64 or i32 (got f32)");
}
catch (...)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment