Commit 17af4266 authored by Scott Cyphers's avatar Scott Cyphers Committed by Robert Kimball

Fix some validation errors (#1603)

parent fe676f72
......@@ -201,18 +201,29 @@ namespace ngraph
{
ostream& operator<<(ostream& out, const Node& node)
{
out << node.description() << '[' << node.get_name() << "](";
string sep = "";
for (auto arg : node.get_arguments())
{
out << sep << arg->get_name();
sep = ", ";
}
out << ")";
return out;
return out << NodeDescription(node, false);
}
}
std::ostream& Node::write_short_description(std::ostream& out) const
{
return out << get_name();
}
std::ostream& Node::write_long_description(std::ostream& out) const
{
out << description() << '[' << get_name() << "](";
string sep = "";
for (auto arg : get_arguments())
{
out << sep << NodeDescription(*arg, true);
sep = ", ";
}
out << ")";
return out;
}
size_t Node::get_output_size() const
{
return m_outputs.size();
......
......@@ -132,6 +132,8 @@ namespace ngraph
virtual bool is_commutative() { return false; }
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
virtual std::ostream& write_short_description(std::ostream&) const;
virtual std::ostream& write_long_description(std::ostream&) const;
// TODO: Deprecate
std::deque<descriptor::Input>& get_inputs() { return m_inputs; }
......@@ -253,6 +255,25 @@ namespace ngraph
}
};
class NodeDescription
{
public:
NodeDescription(const Node& node, bool is_short)
: m_node(node)
, m_is_short(is_short)
{
}
friend std::ostream& operator<<(std::ostream& out, const NodeDescription node_description)
{
return node_description.m_is_short
? node_description.m_node.write_short_description(out)
: node_description.m_node.write_long_description(out);
}
const Node& m_node;
bool m_is_short;
};
void check_new_args_count(const Node* node, const NodeVector& new_args);
}
......
......@@ -85,7 +85,8 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& like_arg,
const AxisSet& broadcast_axes)
: Broadcast("BroadcastLike", {arg, like_arg}, {}, broadcast_axes)
: Broadcast("BroadcastLike", {arg, like_arg}, {}, {})
, m_initial_broadcast_axes(broadcast_axes)
{
constructor_validate_and_infer_types();
}
......@@ -96,18 +97,29 @@ shared_ptr<Node> op::BroadcastLike::copy_with_new_args(const NodeVector& new_arg
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<BroadcastLike>(new_args.at(0), new_args.at(1), m_broadcast_axes);
return make_shared<BroadcastLike>(new_args.at(0), new_args.at(1), m_initial_broadcast_axes);
}
void op::BroadcastLike::infer_shape()
{
const Shape& in_shape = get_input_shape(0);
m_shape = get_input_shape(1);
m_broadcast_axes = m_initial_broadcast_axes;
if (m_broadcast_axes.size() == 0)
{
for (size_t i = in_shape.size(); i < m_shape.size(); ++i)
for (size_t i = 0; i < m_shape.size(); ++i)
{
m_broadcast_axes.insert(i);
if (i < in_shape.size())
{
if (in_shape.at(i) == 1 && m_shape.at(i) > 1)
{
m_broadcast_axes.insert(i);
}
}
else
{
m_broadcast_axes.insert(i);
}
}
}
}
......@@ -80,6 +80,9 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
void infer_shape() override;
protected:
AxisSet m_initial_broadcast_axes;
};
}
}
......@@ -108,7 +108,6 @@ namespace ngraph
void validate_and_infer_types() override
{
Node::validate_and_infer_types();
infer_element_type();
set_output_type(0, m_element_type, m_shape);
}
......
......@@ -25,6 +25,7 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
: Node("GetOutputElement", {arg})
, m_n{n}
{
constructor_validate_and_infer_types();
NODE_VALIDATION_ASSERT(this, m_n < arg->get_output_size())
<< "Output at index " << m_n << " requested, but argument has only "
<< arg->get_output_size() << " outputs.";
......
......@@ -207,7 +207,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
std::unordered_set<std::shared_ptr<Node>> in_bprop;
ngraph::traverse_nodes(bprop,
[&in_bprop](std::shared_ptr<Node> node) {
if (node->get_outputs().size() == 1)
if (node->get_output_size() == 1)
{
if (in_bprop.count(node) == 0)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment