Commit 17af4266 authored by Scott Cyphers's avatar Scott Cyphers Committed by Robert Kimball

Fix some validation errors (#1603)

parent fe676f72
...@@ -201,18 +201,29 @@ namespace ngraph ...@@ -201,18 +201,29 @@ namespace ngraph
{ {
ostream& operator<<(ostream& out, const Node& node) ostream& operator<<(ostream& out, const Node& node)
{ {
out << node.description() << '[' << node.get_name() << "]("; return out << NodeDescription(node, false);
string sep = "";
for (auto arg : node.get_arguments())
{
out << sep << arg->get_name();
sep = ", ";
}
out << ")";
return out;
} }
} }
std::ostream& Node::write_short_description(std::ostream& out) const
{
return out << get_name();
}
std::ostream& Node::write_long_description(std::ostream& out) const
{
out << description() << '[' << get_name() << "](";
string sep = "";
for (auto arg : get_arguments())
{
out << sep << NodeDescription(*arg, true);
sep = ", ";
}
out << ")";
return out;
}
size_t Node::get_output_size() const size_t Node::get_output_size() const
{ {
return m_outputs.size(); return m_outputs.size();
......
...@@ -132,6 +132,8 @@ namespace ngraph ...@@ -132,6 +132,8 @@ namespace ngraph
virtual bool is_commutative() { return false; } virtual bool is_commutative() { return false; }
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
virtual std::ostream& write_short_description(std::ostream&) const;
virtual std::ostream& write_long_description(std::ostream&) const;
// TODO: Deprecate // TODO: Deprecate
std::deque<descriptor::Input>& get_inputs() { return m_inputs; } std::deque<descriptor::Input>& get_inputs() { return m_inputs; }
...@@ -253,6 +255,25 @@ namespace ngraph ...@@ -253,6 +255,25 @@ namespace ngraph
} }
}; };
class NodeDescription
{
public:
NodeDescription(const Node& node, bool is_short)
: m_node(node)
, m_is_short(is_short)
{
}
friend std::ostream& operator<<(std::ostream& out, const NodeDescription node_description)
{
return node_description.m_is_short
? node_description.m_node.write_short_description(out)
: node_description.m_node.write_long_description(out);
}
const Node& m_node;
bool m_is_short;
};
void check_new_args_count(const Node* node, const NodeVector& new_args); void check_new_args_count(const Node* node, const NodeVector& new_args);
} }
......
...@@ -85,7 +85,8 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe ...@@ -85,7 +85,8 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg, op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& like_arg, const std::shared_ptr<Node>& like_arg,
const AxisSet& broadcast_axes) const AxisSet& broadcast_axes)
: Broadcast("BroadcastLike", {arg, like_arg}, {}, broadcast_axes) : Broadcast("BroadcastLike", {arg, like_arg}, {}, {})
, m_initial_broadcast_axes(broadcast_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -96,18 +97,29 @@ shared_ptr<Node> op::BroadcastLike::copy_with_new_args(const NodeVector& new_arg ...@@ -96,18 +97,29 @@ shared_ptr<Node> op::BroadcastLike::copy_with_new_args(const NodeVector& new_arg
{ {
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
} }
return make_shared<BroadcastLike>(new_args.at(0), new_args.at(1), m_broadcast_axes); return make_shared<BroadcastLike>(new_args.at(0), new_args.at(1), m_initial_broadcast_axes);
} }
void op::BroadcastLike::infer_shape() void op::BroadcastLike::infer_shape()
{ {
const Shape& in_shape = get_input_shape(0); const Shape& in_shape = get_input_shape(0);
m_shape = get_input_shape(1); m_shape = get_input_shape(1);
m_broadcast_axes = m_initial_broadcast_axes;
if (m_broadcast_axes.size() == 0) if (m_broadcast_axes.size() == 0)
{ {
for (size_t i = in_shape.size(); i < m_shape.size(); ++i) for (size_t i = 0; i < m_shape.size(); ++i)
{ {
m_broadcast_axes.insert(i); if (i < in_shape.size())
{
if (in_shape.at(i) == 1 && m_shape.at(i) > 1)
{
m_broadcast_axes.insert(i);
}
}
else
{
m_broadcast_axes.insert(i);
}
} }
} }
} }
...@@ -80,6 +80,9 @@ namespace ngraph ...@@ -80,6 +80,9 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void infer_shape() override; void infer_shape() override;
protected:
AxisSet m_initial_broadcast_axes;
}; };
} }
} }
...@@ -108,7 +108,6 @@ namespace ngraph ...@@ -108,7 +108,6 @@ namespace ngraph
void validate_and_infer_types() override void validate_and_infer_types() override
{ {
Node::validate_and_infer_types();
infer_element_type(); infer_element_type();
set_output_type(0, m_element_type, m_shape); set_output_type(0, m_element_type, m_shape);
} }
......
...@@ -25,6 +25,7 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n) ...@@ -25,6 +25,7 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
: Node("GetOutputElement", {arg}) : Node("GetOutputElement", {arg})
, m_n{n} , m_n{n}
{ {
constructor_validate_and_infer_types();
NODE_VALIDATION_ASSERT(this, m_n < arg->get_output_size()) NODE_VALIDATION_ASSERT(this, m_n < arg->get_output_size())
<< "Output at index " << m_n << " requested, but argument has only " << "Output at index " << m_n << " requested, but argument has only "
<< arg->get_output_size() << " outputs."; << arg->get_output_size() << " outputs.";
......
...@@ -207,7 +207,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -207,7 +207,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
std::unordered_set<std::shared_ptr<Node>> in_bprop; std::unordered_set<std::shared_ptr<Node>> in_bprop;
ngraph::traverse_nodes(bprop, ngraph::traverse_nodes(bprop,
[&in_bprop](std::shared_ptr<Node> node) { [&in_bprop](std::shared_ptr<Node> node) {
if (node->get_outputs().size() == 1) if (node->get_output_size() == 1)
{ {
if (in_bprop.count(node) == 0) if (in_bprop.count(node) == 0)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment