Commit 89f1eeed authored by baojun's avatar baojun Committed by Scott Cyphers

add provenance (#3792)

parent 7b618f03
......@@ -33,10 +33,13 @@ using namespace ngraph;
constexpr NodeTypeInfo op::v0::Softmax::type_info;
op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
: Op({arg,
op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector())
->output(0)})
: Op({arg})
{
set_argument(
1,
op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector())
->output(0));
add_provenance_group_member(input_value(1).get_node_shared_ptr());
constructor_validate_and_infer_types();
}
......@@ -68,9 +71,11 @@ const AxisSet op::v0::Softmax::get_axes() const
void op::v0::Softmax::set_axes(const AxisSet& axes)
{
this->input(1).replace_source_output(
op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector())
->output(0));
shared_ptr<Node> current_const = input_value(1).get_node_shared_ptr();
shared_ptr<Node> replacement_const =
op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector());
this->input(1).replace_source_output(replacement_const->output(0));
replace_provenance_group_member(current_const, replacement_const);
}
void op::v0::Softmax::validate_and_infer_types()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment