Commit c8ac0130 authored by Yimei Sun's avatar Yimei Sun Committed by Scott Cyphers

Add get_default_value in Dot op to handle the 0 sized input case (#3834)

parent 64479eb0
......@@ -18,6 +18,7 @@
#include <memory>
#include "ngraph/axis_vector.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/reshape.hpp"
......@@ -201,3 +202,8 @@ void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
auto x_reshaped_dot_delta = make_shared<Dot>(x_reshaped, delta, I_shape.size()); // JI.IK->JK
adjoints.add_delta(y, x_reshaped_dot_delta);
}
shared_ptr<Node> op::Dot::get_default_value() const
{
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
......@@ -60,6 +60,8 @@ namespace ngraph
void validate_and_infer_types() override;
virtual std::shared_ptr<Node> get_default_value() const override;
size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
void set_reduction_axes_count(size_t reduction_axes_count)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment