Commit 89f1eeed authored by baojun's avatar baojun Committed by Scott Cyphers

add provenance (#3792)

parent 7b618f03
...@@ -33,10 +33,13 @@ using namespace ngraph; ...@@ -33,10 +33,13 @@ using namespace ngraph;
constexpr NodeTypeInfo op::v0::Softmax::type_info; constexpr NodeTypeInfo op::v0::Softmax::type_info;
op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes) op::v0::Softmax::Softmax(const Output<Node>& arg, const AxisSet& axes)
: Op({arg, : Op({arg})
op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector())
->output(0)})
{ {
set_argument(
1,
op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector())
->output(0));
add_provenance_group_member(input_value(1).get_node_shared_ptr());
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -68,9 +71,11 @@ const AxisSet op::v0::Softmax::get_axes() const ...@@ -68,9 +71,11 @@ const AxisSet op::v0::Softmax::get_axes() const
void op::v0::Softmax::set_axes(const AxisSet& axes) void op::v0::Softmax::set_axes(const AxisSet& axes)
{ {
this->input(1).replace_source_output( shared_ptr<Node> current_const = input_value(1).get_node_shared_ptr();
op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector()) shared_ptr<Node> replacement_const =
->output(0)); op::Constant::create(element::i64, Shape{axes.to_vector().size()}, axes.to_vector());
this->input(1).replace_source_output(replacement_const->output(0));
replace_provenance_group_member(current_const, replacement_const);
} }
void op::v0::Softmax::validate_and_infer_types() void op::v0::Softmax::validate_and_infer_types()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment