Commit 1e2a3f34 authored by Pruthvi, committed by Scott Cyphers

Softmax + CrossEntropy fusion for numerical stabilization (#3669)

* - WIP fusion pattern for softmax + cross entropy

* fix compiler error

*  make summation axis integer for the fusion pattern

* - Fusion pattern for sigmoid cross entropy bprop

* WIP callback implementation for fused sigmoid + crossentropy fprop

* - implemented fprop softmax+crossentropy as a single layer for numerical
stabilization
- added broadcasting nodes to fix elementwise assertions

* Added unit test case for functionality test

* Move the softmax + crossentropy fusion pass to core

* i) style fix ii) added missing header

* - Added new Fused Op for Softmax + CrossEntropy
- moved the decomposition to the Softmax + CrossEntropy FusedOp

* - Add SoftmaxCrossEntropy for fused tablegen
- Add serializer support for SoftmaxCrossEntropy
- fix documentation

* Added missing json file for unit test case

* Addressed PR comment

* Addressed PR comments

* - Fix fusion string

* - Style fix

* - Added Bprop for Softmax + crossEntropy

* - added SoftmaxCrossEntropy support when soft_label is provided
- serializer and deserializer support for SoftmaxCrossEntropyBprop

* - Added support in decompose_op for SM+CE bprop when ignore_mask is specified

* Updated doc string

* - unit test case for SoftmaxCrossEntropy backprop with soft labels
- fixed decompose_op bug in bprop

* - if soft_label=true, capture the pattern only if the labels don't have one-hot
encoding

* - SoftmaxCrossEntropyBprop support if ignore_index is specified

* add serialized files for unit test

* - fix softmax + CE pattern bug
- fix softmax + CE decompose_op() bug

* - change reduction_axes to int64_t type in fprop and bprop ctor

* - add soft_labels and ignore_index attribute to SM+CE fprop ctor

* - additional asserts in unit test to ensure SM + CE fprop and bprop fusion is successful

* - move reduction_axis computation from ctor to decompose_op to relax constraints on
dynamic shapes

* Addressed PR comments

* - support for SM+CE with ignore_index and soft_label=false

* - test case for SM+CE fprop with ignore_mask, soft_label=false
- fix bug in decompose_op

* - refactor unit test case

* - fix PDPD unit test

* broadcast delta if shape mismatches

* - fix broadcast issue in decompose_op
parent f349593d
......@@ -360,6 +360,8 @@ set (SRC
op/fused/scale_shift.hpp
op/fused/shuffle_channels.cpp
op/fused/shuffle_channels.hpp
op/fused/softmax_crossentropy.cpp
op/fused/softmax_crossentropy.hpp
op/fused/space_to_depth.cpp
op/fused/space_to_depth.hpp
op/fused/split.cpp
......
......@@ -147,6 +147,7 @@ namespace ngraph
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::SoftmaxCrossEntropy::type_info;
op::SoftmaxCrossEntropy::SoftmaxCrossEntropy(const Output<Node>& arg1,
const Output<Node>& arg2,
bool soft_label,
int64_t ignore_index)
: FusedOp({arg1, arg2})
, m_soft_label(soft_label)
, m_ignore_index(ignore_index)
{
constructor_validate_and_infer_types();
}
NodeVector op::SoftmaxCrossEntropy::decompose_op() const
{
auto input_to_normalize = input_value(0);
auto labels = input_value(1);
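// the reduction is over the last axis; it is computed here in decompose_op rather than
// in the ctor so that dynamic input shapes are not constrained at construction time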
auto reduction_axis = input_to_normalize.get_shape().size() - 1;
auto create_mask = [&]() -> std::shared_ptr<ngraph::Node> {
// ignore mask
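// mask = (labels != ignore_index), converted to the input element type and reshaped to {N, 1}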
auto mask_constant = ngraph::op::Constant::create(
labels.get_element_type(), labels.get_shape(), {m_ignore_index});
auto not_equal = std::make_shared<ngraph::op::NotEqual>(labels, mask_constant);
auto convert =
std::make_shared<ngraph::op::Convert>(not_equal, input_to_normalize.get_element_type());
auto reshape = std::make_shared<ngraph::op::Reshape>(
convert, AxisVector{0, 1}, Shape{convert->get_shape().at(0), 1});
return reshape;
};
auto create_xe = [&](std::shared_ptr<ngraph::Node> one_hot,
std::shared_ptr<ngraph::Node> input_softmax) {
auto node_log = std::make_shared<ngraph::op::Log>(input_softmax);
auto node_mul = one_hot * node_log;
auto node_sum = std::make_shared<ngraph::op::Sum>(
node_mul, AxisSet{static_cast<size_t>(reduction_axis)});
return -node_sum;
};
if (m_soft_label)
{
// always reduces the sum on the last axis
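// numerically stable formulation with soft labels:
// loss = -sum_k(labels * (x - max_j(x) - log(sum_j(exp(x - max_j(x))))))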
auto max_xj = std::make_shared<ngraph::op::Max>(
input_to_normalize, AxisSet{static_cast<size_t>(reduction_axis)});
auto broadcast_max_xj = std::make_shared<ngraph::op::Broadcast>(
max_xj, input_to_normalize.get_shape(), AxisSet{1});
auto subtract =
std::make_shared<ngraph::op::Subtract>(input_to_normalize, broadcast_max_xj);
auto exp = std::make_shared<ngraph::op::Exp>(subtract);
auto sum_over_j =
std::make_shared<ngraph::op::Sum>(exp, AxisSet{static_cast<size_t>(reduction_axis)});
auto log_sum_over_j = std::make_shared<ngraph::op::Log>(sum_over_j);
auto subtract_max_xj_from_input =
std::make_shared<ngraph::op::Subtract>(input_to_normalize, broadcast_max_xj);
auto broadcast_log = std::make_shared<ngraph::op::Broadcast>(
log_sum_over_j, subtract_max_xj_from_input->get_shape(), AxisSet{1});
auto subtract_max_xj_from_input_from_log_sum_over_j =
std::make_shared<ngraph::op::Subtract>(subtract_max_xj_from_input, broadcast_log);
// insert dtype conversion if required
if (labels.get_element_type() != input_to_normalize.get_element_type())
{
labels = std::make_shared<ngraph::op::Convert>(labels,
input_to_normalize.get_element_type());
}
auto multiply = std::make_shared<ngraph::op::Multiply>(
labels, subtract_max_xj_from_input_from_log_sum_over_j);
auto sum_over_k = std::make_shared<ngraph::op::Sum>(
multiply, AxisSet{static_cast<size_t>(reduction_axis)});
auto negate_summation = std::make_shared<ngraph::op::Negative>(sum_over_k);
auto reshape = std::make_shared<ngraph::op::Reshape>(
negate_summation, AxisVector{0}, Shape{input_to_normalize.get_shape().at(0), 1});
return {reshape};
}
else
{
// when soft_label = false, the labels are class indices and are one-hot encoded here
size_t one_hot_axis = input_to_normalize.get_shape().size() - 1;
size_t softmax_axis = input_to_normalize.get_shape().size() - 1;
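// hard labels: loss = -sum_k(one_hot(labels) * log(softmax(x))), multiplied by the ignore mask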
auto reshape_labels =
make_shared<op::Reshape>(labels, AxisVector{0, 1}, Shape{labels.get_shape().at(0)});
auto one_hot_labels = std::make_shared<ngraph::op::OneHot>(
reshape_labels, input_to_normalize.get_shape(), one_hot_axis);
auto convert_one_hot = std::make_shared<ngraph::op::Convert>(
one_hot_labels, input_to_normalize.get_element_type());
auto mask = create_mask();
// softmax will be applied on the input to cross_entropy
auto softmax =
std::make_shared<ngraph::op::Softmax>(input_to_normalize, AxisSet{softmax_axis});
auto xe = create_xe(convert_one_hot, softmax);
auto reshape_xe = std::make_shared<ngraph::op::Reshape>(
xe, AxisVector{0}, Shape{xe->get_shape().at(0), 1});
return {reshape_xe * mask};
}
}
shared_ptr<Node> op::SoftmaxCrossEntropy::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<SoftmaxCrossEntropy>(
new_args.at(0), new_args.at(1), m_soft_label, m_ignore_index);
}
constexpr NodeTypeInfo op::SoftmaxCrossEntropyBackprop::type_info;
op::SoftmaxCrossEntropyBackprop::SoftmaxCrossEntropyBackprop(const Output<Node>& delta,
const Output<Node>& softmax,
const Output<Node>& labels,
bool soft_label,
int64_t ignore_index)
: FusedOp({delta, softmax, labels})
, m_soft_label(soft_label)
, m_ignore_index(ignore_index)
{
constructor_validate_and_infer_types();
}
void op::SoftmaxCrossEntropyBackprop::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
}
shared_ptr<Node>
op::SoftmaxCrossEntropyBackprop::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<SoftmaxCrossEntropyBackprop>(
new_args.at(0), new_args.at(1), new_args.at(2), m_soft_label, m_ignore_index);
}
NodeVector op::SoftmaxCrossEntropyBackprop::decompose_op() const
{
auto delta = input_value(0);
auto softmax = input_value(1);
auto labels = input_value(2);
size_t one_hot_axis = delta.get_shape().size() - 1;
// always reduces the sum on the last axis
auto reduction_axis = delta.get_shape().size() - 1;
if (m_soft_label)
{
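// gradient w.r.t. the softmax input with soft labels:
// dX = softmax * sum_k(delta * labels) - delta * labels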
if (delta.get_shape() != labels.get_shape())
{
auto reshape = std::make_shared<ngraph::op::Reshape>(
delta, AxisVector{0, 1}, Shape{delta.get_shape().at(0)});
delta =
std::make_shared<ngraph::op::Broadcast>(reshape, labels.get_shape(), AxisSet{1});
}
auto delta_mul_labels = std::make_shared<ngraph::op::Multiply>(delta, labels);
auto summation_delta_mul_labels = std::make_shared<ngraph::op::Sum>(
delta_mul_labels, AxisSet{static_cast<size_t>(reduction_axis)});
auto broadcast_sum = std::make_shared<ngraph::op::Broadcast>(
summation_delta_mul_labels, softmax.get_shape(), AxisSet{1});
auto multiply_sm = broadcast_sum * softmax;
return {multiply_sm - delta_mul_labels};
}
else
{
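// gradient with hard labels and an ignore mask:
// dX = softmax * sum_k(delta * one_hot(labels) * mask) - delta * one_hot(labels) * mask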
// ignore mask
auto mask_constant =
ngraph::op::Constant::create(element::i64, labels.get_shape(), {m_ignore_index});
auto not_equal = std::make_shared<ngraph::op::NotEqual>(labels, mask_constant);
auto convert = std::make_shared<ngraph::op::Convert>(not_equal, element::f64);
auto reshape = std::make_shared<ngraph::op::Reshape>(
convert, AxisVector{0, 1}, Shape{convert->get_shape().at(0)});
auto broadcast_mask =
std::make_shared<ngraph::op::Broadcast>(reshape, softmax.get_shape(), AxisSet{1});
// one hot encoding of labels
auto reshape_labels =
make_shared<op::Reshape>(labels, AxisVector{0, 1}, Shape{labels.get_shape().at(0)});
auto one_hot =
std::make_shared<ngraph::op::OneHot>(reshape_labels, softmax.get_shape(), one_hot_axis);
auto convert_one_hot = std::make_shared<ngraph::op::Convert>(one_hot, element::f64);
if (delta.get_shape() != convert_one_hot->get_shape())
{
auto reshape = std::make_shared<ngraph::op::Reshape>(
delta, AxisVector{0, 1}, Shape{delta.get_shape().at(0)});
delta = std::make_shared<ngraph::op::Broadcast>(
reshape, convert_one_hot->get_shape(), AxisSet{1});
}
// (cross_entr * delta * mask)
auto delta_mul_labels = std::make_shared<ngraph::op::Multiply>(delta, convert_one_hot);
auto multiply_mask =
std::make_shared<ngraph::op::Multiply>(delta_mul_labels, broadcast_mask);
// sum (cross_entr * delta * mask)
auto summation_delta_mul_labels = std::make_shared<ngraph::op::Sum>(
multiply_mask, AxisSet{static_cast<size_t>(reduction_axis)});
auto broadcast_sum = std::make_shared<ngraph::op::Broadcast>(
summation_delta_mul_labels, softmax.get_shape(), AxisSet{1});
auto multiply_sm_with_summation = broadcast_sum * softmax;
return {multiply_sm_with_summation - multiply_mask};
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
class SoftmaxCrossEntropy : public ngraph::op::util::FusedOp
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"SoftmaxCrossEntropy", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
SoftmaxCrossEntropy() = default;
/// \brief Softmax + CrossEntropy for numerical stabilization
/// \param arg1 Node that produces the tensor to normalize
/// \param arg2 Node that produces ground truth labels for the input
/// \param soft_label flag indicating whether to interpret the given labels as soft
/// labels
/// \param ignore_index Specifies a target value that is ignored and does not contribute
/// to the input gradient. Only valid if soft_label is set to false
SoftmaxCrossEntropy(const Output<Node>& arg1,
const Output<Node>& arg2,
bool soft_label = false,
int64_t ignore_index = -100);
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_soft_label() const { return m_soft_label; }
int64_t get_ignore_index() const { return m_ignore_index; }
private:
bool m_soft_label;
int64_t m_ignore_index;
};
class SoftmaxCrossEntropyBackprop : public util::FusedOp
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"SoftmaxCrossEntropyBackprop", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
SoftmaxCrossEntropyBackprop() = default;
/// \brief Backprop for SoftmaxCrossEntropy
/// \param delta Node that produces the delta during bprop
/// \param softmax Node that produces softmax from fprop
/// \param labels Node that produces ground truth labels for input
/// \param soft_label flag indicating whether to interpret the given labels as soft
/// labels
/// \param ignore_index Specifies a target value that is ignored and does not contribute
/// to the input gradient. Only valid if soft_label is set to false
SoftmaxCrossEntropyBackprop(const Output<Node>& delta,
const Output<Node>& softmax,
const Output<Node>& labels,
bool soft_label = false,
int64_t ignore_index = -100);
virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_soft_label() const { return m_soft_label; }
int64_t get_ignore_index() const { return m_ignore_index; }
private:
bool m_soft_label;
int64_t m_ignore_index;
};
} // namespace op
} // namespace ngraph
......@@ -53,5 +53,7 @@ NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropy, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropyBackprop, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
......@@ -25,14 +25,20 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
......@@ -41,6 +47,7 @@
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
......@@ -55,6 +62,173 @@ static shared_ptr<Node> construct_constant_node(int n)
return op::Constant::create(element::f32, Shape{}, {n});
}
void pass::CoreFusion::construct_softmax_cross_entropy_fprop()
{
auto param_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{41, 37});
auto softmax = std::make_shared<ngraph::op::Softmax>(param_1, AxisSet{1});
// parameter with one-hot encoded values
auto param_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{41, 37});
auto log = std::make_shared<ngraph::op::Log>(softmax);
auto multiply = std::make_shared<ngraph::op::Multiply>(param_2, log);
auto reduction_axes = ngraph::op::Constant::create(element::i64, Shape{}, {1});
auto reduction_axes_label = std::make_shared<pattern::op::Label>(reduction_axes);
auto sum = std::make_shared<ngraph::op::Sum>(multiply, reduction_axes_label);
auto negative = std::make_shared<ngraph::op::Negative>(sum);
auto reshape = std::make_shared<ngraph::op::Reshape>(negative, AxisVector{0}, Shape{41, 1});
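// the graph above matches reshape(-sum(labels * log(softmax(x)), axis)), i.e. the unfused
// softmax + cross entropy fprop subgraph that is replaced by the fused op below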
auto callback = [reduction_axes_label, param_1, param_2](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_softmax_cross_entropy_fprop against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto input_to_normalize = pattern_map[param_1];
auto labels = pattern_map[param_2];
auto softmax_crossentropy =
std::make_shared<ngraph::op::SoftmaxCrossEntropy>(input_to_normalize, labels, true);
ngraph::replace_node(m.get_match_root(), softmax_crossentropy);
return true;
};
auto m = std::make_shared<pattern::Matcher>(reshape, "CoreFusion.SoftmaxCrossEntropy");
this->add_matcher(m, callback);
}
void pass::CoreFusion::construct_softmax_cross_entropy_bprop_with_soft_labels()
{
// Softmax bprop
auto input_x = std::make_shared<pattern::op::Label>(element::f32, Shape{41, 37});
auto constant_1 = ngraph::op::Constant::create(element::i64, Shape{1}, {1});
auto max_x = std::make_shared<ngraph::op::Max>(input_x, constant_1);
auto broadcast_max_x =
std::make_shared<ngraph::op::Broadcast>(max_x, Shape{41, 37}, AxisSet{1});
auto subtract_input_x = std::make_shared<ngraph::op::Subtract>(input_x, broadcast_max_x);
auto constant_2 = ngraph::op::Constant::create(element::f32, Shape{41, 37}, {1});
auto maximum = std::make_shared<ngraph::op::Maximum>(constant_2, subtract_input_x);
auto softmax_axes = ngraph::op::Constant::create(element::i64, Shape{1}, {1});
auto softmax = std::make_shared<ngraph::op::Softmax>(maximum, softmax_axes);
auto softmax_label =
std::make_shared<pattern::op::Label>(softmax, nullptr, NodeVector{softmax});
// Cross Entropy Bprop
auto delta_label = std::make_shared<pattern::op::Label>(element::f32, Shape{41, 37});
// if soft_label = true, the labels will not be one-hot encoded;
// instead they are a 2D floating point tensor
auto labels_y = std::make_shared<pattern::op::Label>(
element::f32, Shape{41, 37}, pattern::has_class<op::Parameter>());
auto negative_y = std::make_shared<ngraph::op::Negative>(labels_y);
auto multiply_ce = std::make_shared<ngraph::op::Multiply>(negative_y, delta_label);
// summation
auto divide_sm_ce = std::make_shared<ngraph::op::Divide>(multiply_ce, softmax_label);
auto multiply_sm_ce = std::make_shared<ngraph::op::Multiply>(softmax_label, divide_sm_ce);
auto reduction_axes_label = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
auto summation = std::make_shared<ngraph::op::Sum>(multiply_sm_ce, reduction_axes_label);
auto broadcast_summation =
std::make_shared<ngraph::op::Broadcast>(summation, Shape{41, 37}, AxisSet{1});
auto subtract = std::make_shared<ngraph::op::Subtract>(divide_sm_ce, broadcast_summation);
auto multiply = std::make_shared<ngraph::op::Multiply>(softmax_label, subtract);
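// the graph above is softmax bprop applied to the cross entropy adjoint
// d = (-labels * delta) / softmax, i.e. softmax * (d - sum(softmax * d))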
auto callback = [input_x, delta_label, labels_y, reduction_axes_label, softmax_label](
pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_softmax_cross_entropy_bprop against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto input = pattern_map[input_x];
auto labels = pattern_map[labels_y];
auto delta = pattern_map[delta_label];
auto softmax = pattern_map[softmax_label];
auto sm_ce_bprop =
std::make_shared<ngraph::op::SoftmaxCrossEntropyBackprop>(delta, softmax, labels, true);
ngraph::replace_node(m.get_match_root(), sm_ce_bprop);
return true;
};
auto m = std::make_shared<pattern::Matcher>(multiply, "CoreFusion.SoftmaxCrossEntropyBprop");
this->add_matcher(m, callback);
}
void pass::CoreFusion::construct_softmax_cross_entropy_bprop_with_ignore_mask()
{
// Softmax bprop
auto input_x = std::make_shared<pattern::op::Label>(element::f64, Shape{41, 37});
auto constant_1 = ngraph::op::Constant::create(element::i64, Shape{1}, {1});
auto max_x = std::make_shared<ngraph::op::Max>(input_x, constant_1);
auto broadcast_max_x =
std::make_shared<ngraph::op::Broadcast>(max_x, Shape{41, 37}, AxisSet{1});
auto subtract_input_x = std::make_shared<ngraph::op::Subtract>(input_x, broadcast_max_x);
auto constant_2 = ngraph::op::Constant::create(element::f64, Shape{41, 37}, {1});
auto maximum = std::make_shared<ngraph::op::Maximum>(constant_2, subtract_input_x);
auto softmax_axes = ngraph::op::Constant::create(element::i64, Shape{1}, {1});
auto softmax = std::make_shared<ngraph::op::Softmax>(maximum, softmax_axes);
auto softmax_label =
std::make_shared<pattern::op::Label>(softmax, nullptr, NodeVector{softmax});
// labels
auto labels_y = std::make_shared<pattern::op::Label>(
element::i64, Shape{41, 1}, pattern::has_class<op::Parameter>());
// ignore_mask
auto mask_constant = ngraph::op::Constant::create(element::i64, Shape{41, 1}, {1});
auto mask_label = std::make_shared<pattern::op::Label>(mask_constant);
auto not_equal = std::make_shared<ngraph::op::NotEqual>(labels_y, mask_label);
auto convert = std::make_shared<ngraph::op::Convert>(not_equal, element::f64);
auto reshape = std::make_shared<ngraph::op::Reshape>(
convert, AxisVector{0, 1}, Shape{convert->get_shape().at(0)});
auto broadcast_mask =
std::make_shared<ngraph::op::Broadcast>(reshape, Shape{41, 37}, AxisSet{1});
// Cross Entropy Bprop
auto delta_label = std::make_shared<pattern::op::Label>(element::f64, Shape{41, 37});
// if ignore_mask is used, the pattern contains a one-hot encoding of the labels
auto reshape_labels = make_shared<op::Reshape>(labels_y, AxisVector{0, 1}, Shape{41});
auto one_hot = std::make_shared<ngraph::op::OneHot>(reshape_labels, Shape{41, 37}, size_t(1));
auto convert_one_hot = std::make_shared<ngraph::op::Convert>(one_hot, element::f64);
auto negative_y = std::make_shared<ngraph::op::Negative>(convert_one_hot);
auto multiply_ce = std::make_shared<ngraph::op::Multiply>(negative_y, delta_label);
// summation
auto divide_sm_ce = std::make_shared<ngraph::op::Divide>(multiply_ce, softmax_label);
auto multiply_mask = std::make_shared<ngraph::op::Multiply>(divide_sm_ce, broadcast_mask);
auto multiply_sm_ce = std::make_shared<ngraph::op::Multiply>(softmax_label, multiply_mask);
auto reduction_axes_label = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
auto summation = std::make_shared<ngraph::op::Sum>(multiply_sm_ce, reduction_axes_label);
auto broadcast_summation =
std::make_shared<ngraph::op::Broadcast>(summation, Shape{41, 37}, AxisSet{1});
auto subtract = std::make_shared<ngraph::op::Subtract>(multiply_mask, broadcast_summation);
auto multiply = std::make_shared<ngraph::op::Multiply>(softmax_label, subtract);
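// same structure as the soft-label bprop pattern, except the adjoint is additionally
// multiplied by the broadcast ignore mask before the reduction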
auto callback = [input_x,
delta_label,
labels_y,
reduction_axes_label,
softmax_label,
mask_label](pattern::Matcher& m) {
NGRAPH_DEBUG
<< "In a callback for construct_softmax_cross_entropy_bprop_with_ignore_mask against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto input = pattern_map[input_x];
auto labels = pattern_map[labels_y];
auto delta = pattern_map[delta_label];
auto softmax = pattern_map[softmax_label];
auto mask_constant_op =
std::static_pointer_cast<ngraph::op::Constant>(pattern_map[mask_label]);
auto ignore_index = *(static_cast<size_t const*>(mask_constant_op->get_data_ptr()));
auto sm_ce_bprop = std::make_shared<ngraph::op::SoftmaxCrossEntropyBackprop>(
delta, softmax, labels, false, ignore_index);
ngraph::replace_node(m.get_match_root(), sm_ce_bprop);
return true;
};
auto m = std::make_shared<pattern::Matcher>(multiply, "CoreFusion.SoftmaxCrossEntropyBprop");
this->add_matcher(m, callback);
}
void pass::CoreFusion::construct_relu()
{
auto iconst0 = construct_constant_node(0);
......
......@@ -45,6 +45,9 @@ public:
construct_zero_padded_reshaped_conv();
construct_zero_padded_conv();
construct_zero_padded_conv_backprop_filters();
construct_softmax_cross_entropy_fprop();
construct_softmax_cross_entropy_bprop_with_soft_labels();
construct_softmax_cross_entropy_bprop_with_ignore_mask();
}
// Patterns under FOP_FUSIONS create ops (FusedOps) that might not
// be all supported by certain backends. In such a case, backends
......@@ -69,4 +72,7 @@ public:
void construct_zero_padded_conv_backprop_filters();
void construct_conv_bias();
void construct_conv_bias_add();
void construct_softmax_cross_entropy_fprop();
void construct_softmax_cross_entropy_bprop_with_soft_labels();
void construct_softmax_cross_entropy_bprop_with_ignore_mask();
};
......@@ -86,6 +86,7 @@
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -1200,7 +1201,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
return false;
}
}
if (dex)
{
auto handler = GetGlobalBuildDispatcher().find(type_index(typeid(node)));
......@@ -1243,6 +1243,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
REGISTER_KNOBBED_PASS(RecurrentReshapeElimination, false, ngraph::pass)
REGISTER_KNOBBED_PASS_WITH_ARGS(
CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS)
REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported)
REGISTER_KNOBBED_PASS(CPUPreFusion, true, runtime::cpu::pass)
// Disable CPUFusion if MLIR is enabled to preserve core ops.
......
......@@ -91,6 +91,7 @@
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
......@@ -2169,6 +2170,21 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
break;
}
case OP_TYPEID::SoftmaxCrossEntropy:
{
auto soft_label = node_js.at("soft_label");
auto ignore_index = node_js.at("ignore_index");
node = make_shared<op::SoftmaxCrossEntropy>(args[0], args[1], soft_label, ignore_index);
break;
}
case OP_TYPEID::SoftmaxCrossEntropyBackprop:
{
auto soft_label = node_js.at("soft_label");
auto ignore_index = node_js.at("ignore_index");
node = make_shared<op::SoftmaxCrossEntropyBackprop>(
args[0], args[1], args[2], soft_label, ignore_index);
break;
}
case OP_TYPEID::SpaceToDepth:
{
auto block_size = node_js.at("block_size").get<size_t>();
......@@ -3481,6 +3497,20 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::SoftmaxCrossEntropy:
{
auto tmp = static_cast<const op::SoftmaxCrossEntropy*>(&n);
node["soft_label"] = tmp->get_soft_label();
node["ignore_index"] = tmp->get_ignore_index();
break;
}
case OP_TYPEID::SoftmaxCrossEntropyBackprop:
{
auto tmp = static_cast<const op::SoftmaxCrossEntropyBackprop*>(&n);
node["soft_label"] = tmp->get_soft_label();
node["ignore_index"] = tmp->get_ignore_index();
break;
}
case OP_TYPEID::Tan: { break;
}
case OP_TYPEID::Tanh: { break;
......
......@@ -702,3 +702,154 @@ TEST(batch_fusion, pass_property)
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
#ifndef NGRAPH_JSON_DISABLE
TEST(core_fusion, softmax_crossentropy_fprop_1)
{
const std::string file_name("paddlepaddle/ngraph-paddlepaddle-function3.json");
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<double> rng(-1.0, 1.0);
vector<vector<double>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<double> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i)));
}
// during this optimization for numerical stability, softmax + cross entropy is reduced to
// -sum(labels * (input - max(input) - log(sum(exp(input - max(input))))))
// the count of op::Softmax should be equal to zero if the fusion is successful
size_t softmax = count_ops_of_type<op::Softmax>(cpu_f);
ASSERT_EQ(softmax, 0);
}
TEST(core_fusion, softmax_crossentropy_fprop_2)
{
const std::string file_name("paddlepaddle/ngraph-paddlepaddle-function1.json");
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<double> rng(-1.0, 1.0);
vector<vector<double>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<double> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i)));
}
// during this optimization for numerical stability, softmax + cross entropy is reduced to
// -sum(labels * (input - max(input) - log(sum(exp(input - max(input))))))
// the count of op::Softmax should be equal to zero if the fusion is successful
size_t softmax = count_ops_of_type<op::Softmax>(cpu_f);
ASSERT_EQ(softmax, 0);
}
TEST(core_fusion, softmax_crossentropy_bprop_with_soft_labels)
{
const std::string file_name("paddlepaddle/ngraph-paddlepaddle-bprop0.json");
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<double> rng(-1.0, 1.0);
vector<vector<double>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<double> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i)));
}
// during this optimization for numerical stability we eliminate (softmax / softmax);
// the number of Divide ops in cpu_f should be zero if the fusion is valid
size_t divide = count_ops_of_type<op::Divide>(cpu_f);
ASSERT_EQ(divide, 0);
}
TEST(core_fusion, softmax_crossentropy_bprop_with_ignore_mask)
{
const std::string file_name("paddlepaddle/ngraph-paddlepaddle-bprop1.json");
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<double> rng(-1.0, 1.0);
vector<vector<double>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<double> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i)));
}
// during this optimization for numerical stability we eliminate (softmax / softmax);
// the number of Divide ops in cpu_f should be zero if the fusion is valid
size_t divide = count_ops_of_type<op::Divide>(cpu_f);
ASSERT_EQ(divide, 0);
}
#endif
void test_softmax_crossentropy(Shape input_shape,
Shape label_shape,
bool soft_label,
int64_t ignore_index)
{
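// builds a single fused SoftmaxCrossEntropy op, executes it on CPU, and checks that
// decompose_op introduced the expected helper ops (OneHot and the NotEqual ignore mask)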
auto input = std::make_shared<op::Parameter>(element::f64, input_shape);
auto labels = std::make_shared<op::Parameter>(element::i64, label_shape);
auto sm_ce = std::make_shared<op::SoftmaxCrossEntropy>(input, labels, soft_label, ignore_index);
auto cpu_f = make_shared<Function>(sm_ce, ParameterVector{input, labels});
test::Uniform<double> rng(-1.0, 1.0);
vector<vector<double>> args;
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{
vector<double> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto cpu_results = execute(cpu_f, args, "CPU");
// if soft_label = false, decompose_op emits exactly one OneHot for the labels
if (!soft_label)
{
size_t onehot = count_ops_of_type<op::OneHot>(cpu_f);
ASSERT_EQ(onehot, 1);
}
// check for the ignore mask
if (ignore_index >= 0 && !soft_label)
{
size_t not_equal = count_ops_of_type<op::NotEqual>(cpu_f);
ASSERT_EQ(not_equal, 1);
}
}
TEST(core_fusion, softmax_crossentropy)
{
test_softmax_crossentropy(Shape{41, 37}, Shape{41, 37}, true, -1);
test_softmax_crossentropy(Shape{41, 37}, Shape{41, 1}, false, 5);
}
[
{
"name": "Function_2",
"ops": [
{
"cacheable": true,
"element_type": "double",
"name": "Parameter_1571",
"op": "Parameter",
"op_version": 0,
"outputs": [
"Parameter_1571_0"
],
"shape": [
41,
37
]
},
{
"cacheable": true,
"element_type": "double",
"name": "Parameter_1572",
"op": "Parameter",
"op_version": 0,
"outputs": [
"Parameter_1572_0"
],
"shape": [
41,
37
]
},
{
"inputs": [
"Parameter_1571"
],
"name": "Negative_1607",
"op": "Negative",
"op_version": 0,
"outputs": [
"Negative_1607_0"
]
},
{
"element_type": "double",
"name": "Constant_1597",
"op": "Constant",
"op_version": 0,
"outputs": [
"Constant_1597_0"
],
"shape": [
1
],
"value": [
"1"
]
},
{
"element_type": "double",
"name": "Constant_1598",
"op": "Constant",
"op_version": 0,
"outputs": [
"Constant_1598_0"
],
"shape": [
1
],
"value": [
"41"
]
},
{
"inputs": [
"Constant_1597",
"Constant_1598"
],
"name": "Divide_1601",
"op": "Divide",
"op_version": 0,
"outputs": [
"Divide_1601_0"
],
"pythondiv": true
},
{
"axes": [
0,
1
],
"inputs": [
"Divide_1601"
],
"name": "Broadcast_1602",
"op": "Broadcast",
"op_version": 0,
"outputs": [
"Broadcast_1602_0"
],
"shape": [
41,
1,
1
]
},
{
"input_order": [
0,
1,
2
],
"inputs": [
"Broadcast_1602"
],
"name": "Reshape_1603",
"op": "Reshape",
"op_version": 0,
"output_shape": [
41,
1
],
"outputs": [
"Reshape_1603_0"
]
},
{
"element_type": "double",
"name": "Constant_1600",
"op": "Constant",
"op_version": 0,
"outputs": [
"Constant_1600_0"
],
"shape": [
41,
1
],
"value": [
"0"
]
},
{
"inputs": [
"Reshape_1603",
"Constant_1600"
],
"name": "Add_1604",
"op": "Add",
"op_version": 0,
"outputs": [
"Add_1604_0"
]
},
{
"input_order": [
0,
1
],
"inputs": [
"Add_1604"
],
"name": "Reshape_1605",
"op": "Reshape",
"op_version": 0,
"output_shape": [
41
],
"outputs": [
"Reshape_1605_0"
]
},
{
"axes": [
1
],
"inputs": [
"Reshape_1605"
],
"name": "Broadcast_1606",
"op": "Broadcast",
"op_version": 0,
"outputs": [
"Broadcast_1606_0"
],
"shape": [
41,
37
]
},
{
"inputs": [
"Negative_1607",
"Broadcast_1606"
],
"name": "Multiply_1608",
"op": "Multiply",
"op_version": 0,
"outputs": [
"Multiply_1608_0"
]
},
{
"element_type": "double",
"name": "Constant_1583",
"op": "Constant",
"op_version": 0,
"outputs": [
"Constant_1583_0"
],
"shape": [
41,
37
],
"value": [
"-64"
]
},
{
"element_type": "int64_t",
"name": "Constant_1579",
"op": "Constant",
"op_version": 0,
"outputs": [
"Constant_1579_0"
],
"shape": [
1
],
"value": [
"1"
]
},
{
"inputs": [
"Parameter_1572",
"Constant_1579"
],
"name": "Max_1580",
"op": "Max",
"op_version": 0,
"outputs": [
"Max_1580_0"
],
"reduction_axes": [
1
]
},
{
"axes": [
1
],
"inputs": [
"Max_1580"
],
"name": "Broadcast_1581",
"op": "Broadcast",
"op_version": 0,
"outputs": [
"Broadcast_1581_0"
],
"shape": [
41,
37
]
},
{
"inputs": [
"Parameter_1572",
"Broadcast_1581"
],
"name": "Subtract_1582",
"op": "Subtract",
"op_version": 0,
"outputs": [
"Subtract_1582_0"
]
},
{
"inputs": [
"Constant_1583",
"Subtract_1582"
],
"name": "Maximum_1584",
"op": "Maximum",
"op_version": 0,
"outputs": [
"Maximum_1584_0"
]
},
{
"inputs": [
"Maximum_1584"
],
"name": "Softmax_1585",
"op": "Softmax",
"op_version": 0,
"outputs": [
"Softmax_1585_0"
],
"softmax_axes": [
1
]
},
{
"inputs": [
"Multiply_1608",
"Softmax_1585"
],
"name": "Divide_1609",
"op": "Divide",
"op_version": 0,
"outputs": [
"Divide_1609_0"
],
"pythondiv": true
},
{
"inputs": [
"Softmax_1585",
"Divide_1609"
],
"name": "Multiply_1610",
"op": "Multiply",
"op_version": 0,
"outputs": [
"Multiply_1610_0"
]
},
{
"element_type": "int64_t",
"name": "Constant_1611",
"op": "Constant",
"op_version": 0,
"outputs": [
"Constant_1611_0"
],
"shape": [
1
],
"value": [
"1"
]
},
{
"inputs": [
"Multiply_1610",
"Constant_1611"
],
"name": "Sum_1612",
"op": "Sum",
"op_version": 0,
"outputs": [
"Sum_1612_0"
],
"reduction_axes": [
1
]
},
{
"axes": [
1
],
"inputs": [
"Sum_1612"
],
"name": "Broadcast_1613",
"op": "Broadcast",
"op_version": 0,
"outputs": [
"Broadcast_1613_0"
],
"shape": [
41,
37
]
},
{
"inputs": [
"Divide_1609",
"Broadcast_1613"
],
"name": "Subtract_1614",
"op": "Subtract",
"op_version": 0,
"outputs": [
"Subtract_1614_0"
]
},
{
"inputs": [
"Subtract_1614",
"Softmax_1585"
],
"name": "Multiply_1615",
"op": "Multiply",
"op_version": 0,
"outputs": [
"Multiply_1615_0"
]
},
{
"inputs": [
"Multiply_1615"
],
"name": "Result_1616",
"needs_default_layout": true,
"op": "Result",
"op_version": 0,
"outputs": [
"Result_1616_0"
]
}
],
"parameters": [
"Parameter_1571",
"Parameter_1572"
],
"result": [
"Result_1616"
]
}
]
[{"name":"Function_4","ops":[{"cacheable":true,"element_type":"int64_t","name":"Parameter_3133","op":"Parameter","op_version":0,"outputs":["Parameter_3133_0"],"shape":[41,1]},{"cacheable":true,"element_type":"double","name":"Parameter_3134","op":"Parameter","op_version":0,"outputs":["Parameter_3134_0"],"shape":[41,37]},{"input_order":[0,1],"inputs":["Parameter_3133"],"name":"Reshape_3179","op":"Reshape","op_version":0,"output_shape":[41],"outputs":["Reshape_3179_0"]},{"inputs":["Reshape_3179"],"name":"OneHot_3180","one_hot_axis":1,"op":"OneHot","op_version":0,"outputs":["OneHot_3180_0"],"shape":[41,37]},{"inputs":["OneHot_3180"],"name":"Convert_3183","op":"Convert","op_version":0,"outputs":["Convert_3183_0"],"target_type":"double"},{"inputs":["Convert_3183"],"name":"Negative_3184","op":"Negative","op_version":0,"outputs":["Negative_3184_0"]},{"element_type":"double","name":"Constant_3166","op":"Constant","op_version":0,"outputs":["Constant_3166_0"],"shape":[1],"value":["1"]},{"element_type":"double","name":"Constant_3167","op":"Constant","op_version":0,"outputs":["Constant_3167_0"],"shape":[1],"value":["41"]},{"inputs":["Constant_3166","Constant_3167"],"name":"Divide_3170","op":"Divide","op_version":0,"outputs":["Divide_3170_0"],"pythondiv":true},{"axes":[0,1],"inputs":["Divide_3170"],"name":"Broadcast_3171","op":"Broadcast","op_version":0,"outputs":["Broadcast_3171_0"],"shape":[41,1,1]},{"input_order":[0,1,2],"inputs":["Broadcast_3171"],"name":"Reshape_3172","op":"Reshape","op_version":0,"output_shape":[41,1],"outputs":["Reshape_3172_0"]},{"element_type":"double","name":"Constant_3169","op":"Constant","op_version":0,"outputs":["Constant_3169_0"],"shape":[41,1],"value":["0"]},{"inputs":["Reshape_3172","Constant_3169"],"name":"Add_3173","op":"Add","op_version":0,"outputs":["Add_3173_0"]},{"input_order":[0,1],"inputs":["Add_3173"],"name":"Reshape_3181","op":"Reshape","op_version":0,"output_shape":[41],"outputs":["Reshape_3181_0"]},{"axes":[1],"inputs":["Reshape_3181"],"name":"Broadcast_3182","op":"Broadcast","op_version":0,"outputs":["Broadcast_3182_0"],"shape":[41,37]},{"inputs":["Negative_3184","Broadcast_3182"],"name":"Multiply_3185","op":"Multiply","op_version":0,"outputs":["Multiply_3185_0"]},{"element_type":"double","name":"Constant_3145","op":"Constant","op_version":0,"outputs":["Constant_3145_0"],"shape":[41,37],"value":["-64"]},{"element_type":"int64_t","name":"Constant_3141","op":"Constant","op_version":0,"outputs":["Constant_3141_0"],"shape":[1],"value":["1"]},{"inputs":["Parameter_3134","Constant_3141"],"name":"Max_3142","op":"Max","op_version":0,"outputs":["Max_3142_0"],"reduction_axes":[1]},{"axes":[1],"inputs":["Max_3142"],"name":"Broadcast_3143","op":"Broadcast","op_version":0,"outputs":["Broadcast_3143_0"],"shape":[41,37]},{"inputs":["Parameter_3134","Broadcast_3143"],"name":"Subtract_3144","op":"Subtract","op_version":0,"outputs":["Subtract_3144_0"]},{"inputs":["Constant_3145","Subtract_3144"],"name":"Maximum_3146","op":"Maximum","op_version":0,"outputs":["Maximum_3146_0"]},{"inputs":["Maximum_3146"],"name":"Softmax_3147","op":"Softmax","op_version":0,"outputs":["Softmax_3147_0"],"softmax_axes":[1]},{"inputs":["Multiply_3185","Softmax_3147"],"name":"Divide_3186","op":"Divide","op_version":0,"outputs":["Divide_3186_0"],"pythondiv":true},{"element_type":"int64_t","name":"Constant_3174","op":"Constant","op_version":0,"outputs":["Constant_3174_0"],"shape":[41,1],"value":["5"]},{"inputs":["Parameter_3133","Constant_3174"],"name":"NotEqual_3175","op":"NotEqual","op_version":0,"out
puts":["NotEqual_3175_0"]},{"inputs":["NotEqual_3175"],"name":"Convert_3176","op":"Convert","op_version":0,"outputs":["Convert_3176_0"],"target_type":"double"},{"input_order":[0,1],"inputs":["Convert_3176"],"name":"Reshape_3177","op":"Reshape","op_version":0,"output_shape":[41],"outputs":["Reshape_3177_0"]},{"axes":[1],"inputs":["Reshape_3177"],"name":"Broadcast_3178","op":"Broadcast","op_version":0,"outputs":["Broadcast_3178_0"],"shape":[41,37]},{"inputs":["Divide_3186","Broadcast_3178"],"name":"Multiply_3187","op":"Multiply","op_version":0,"outputs":["Multiply_3187_0"]},{"inputs":["Softmax_3147","Multiply_3187"],"name":"Multiply_3188","op":"Multiply","op_version":0,"outputs":["Multiply_3188_0"]},{"element_type":"int64_t","name":"Constant_3189","op":"Constant","op_version":0,"outputs":["Constant_3189_0"],"shape":[1],"value":["1"]},{"inputs":["Multiply_3188","Constant_3189"],"name":"Sum_3190","op":"Sum","op_version":0,"outputs":["Sum_3190_0"],"reduction_axes":[1]},{"axes":[1],"inputs":["Sum_3190"],"name":"Broadcast_3191","op":"Broadcast","op_version":0,"outputs":["Broadcast_3191_0"],"shape":[41,37]},{"inputs":["Multiply_3187","Broadcast_3191"],"name":"Subtract_3192","op":"Subtract","op_version":0,"outputs":["Subtract_3192_0"]},{"inputs":["Subtract_3192","Softmax_3147"],"name":"Multiply_3193","op":"Multiply","op_version":0,"outputs":["Multiply_3193_0"]},{"inputs":["Multiply_3193"],"name":"Result_3194","needs_default_layout":true,"op":"Result","op_version":0,"outputs":["Result_3194_0"]}],"parameters":["Parameter_3133","Parameter_3134"],"result":["Result_3194"]}]
[
{
"name": "Function_1",
"ops": [
{
"cacheable": true,
"element_type": "int64_t",
"name": "Parameter_785",
"op": "Parameter",
"op_version": 0,
"outputs": [
"Parameter_785_0"
],
"shape": [
41,
1
]
},
{
"cacheable": true,
"element_type": "double",
"name": "Parameter_786",
"op": "Parameter",
"op_version": 0,
"outputs": [
"Parameter_786_0"
],
"shape": [
41,
37
]
},
{
"input_order": [
0,
1
],
"inputs": [
"Parameter_785"
],
"name": "Reshape_790",
"op": "Reshape",
"op_version": 0,
"output_shape": [
41
],
"outputs": [
"Reshape_790_0"
]
},
{
"inputs": [
"Reshape_790"
],
"name": "OneHot_791",
"one_hot_axis": 1,
"op": "OneHot",
"op_version": 0,
"outputs": [
"OneHot_791_0"
],
"shape": [
41,
37
]
},
{
"inputs": [
"OneHot_791"
],
"name": "Convert_792",
"op": "Convert",
"op_version": 0,
"outputs": [
"Convert_792_0"
],
"target_type": "double"
},
{
"inputs": [
"Parameter_786"
],
"name": "Softmax_789",
"op": "Softmax",
"op_version": 0,
"outputs": [
"Softmax_789_0"
],
"softmax_axes": [
1
]
},
{
"inputs": [
"Softmax_789"
],
"name": "Log_793",
"op": "Log",
"op_version": 0,
"outputs": [
"Log_793_0"
]
},
{
"inputs": [
"Convert_792",
"Log_793"
],
"name": "Multiply_794",
"op": "Multiply",
"op_version": 0,
"outputs": [
"Multiply_794_0"
]
},
{
"element_type": "int64_t",
"name": "Constant_795",
"op": "Constant",
"op_version": 0,
"outputs": [
"Constant_795_0"
],
"shape": [
1
],
"value": [
"1"
]
},
{
"inputs": [
"Multiply_794",
"Constant_795"
],
"name": "Sum_796",
"op": "Sum",
"op_version": 0,
"outputs": [
"Sum_796_0"
],
"reduction_axes": [
1
]
},
{
"inputs": [
"Sum_796"
],
"name": "Negative_797",
"op": "Negative",
"op_version": 0,
"outputs": [
"Negative_797_0"
]
},
{
"input_order": [
0
],
"inputs": [
"Negative_797"
],
"name": "Reshape_798",
"op": "Reshape",
"op_version": 0,
"output_shape": [
41,
1
],
"outputs": [
"Reshape_798_0"
]
},
{
"inputs": [
"Reshape_798"
],
"name": "Result_799",
"needs_default_layout": true,
"op": "Result",
"op_version": 0,
"outputs": [
"Result_799_0"
]
}
],
"parameters": [
"Parameter_785",
"Parameter_786"
],
"result": [
"Result_799"
]
}
]
[
{
"name": "Function_3",
"ops": [
{
"cacheable": true,
"element_type": "double",
"name": "Parameter_2309",
"op": "Parameter",
"op_version": 0,
"outputs": [
"Parameter_2309_0"
],
"shape": [
41,
37
]
},
{
"cacheable": true,
"element_type": "double",
"name": "Parameter_2310",
"op": "Parameter",
"op_version": 0,
"outputs": [
"Parameter_2310_0"
],
"shape": [
41,
37
]
},
{
"inputs": [
"Parameter_2310"
],
"name": "Softmax_2313",
"op": "Softmax",
"op_version": 0,
"outputs": [
"Softmax_2313_0"
],
"softmax_axes": [
1
]
},
{
"inputs": [
"Softmax_2313"
],
"name": "Log_2314",
"op": "Log",
"op_version": 0,
"outputs": [
"Log_2314_0"
]
},
{
"inputs": [
"Parameter_2309",
"Log_2314"
],
"name": "Multiply_2315",
"op": "Multiply",
"op_version": 0,
"outputs": [
"Multiply_2315_0"
]
},
{
"element_type": "int64_t",
"name": "Constant_2316",
"op": "Constant",
"op_version": 0,
"outputs": [
"Constant_2316_0"
],
"shape": [
1
],
"value": [
"1"
]
},
{
"inputs": [
"Multiply_2315",
"Constant_2316"
],
"name": "Sum_2317",
"op": "Sum",
"op_version": 0,
"outputs": [
"Sum_2317_0"
],
"reduction_axes": [
1
]
},
{
"inputs": [
"Sum_2317"
],
"name": "Negative_2318",
"op": "Negative",
"op_version": 0,
"outputs": [
"Negative_2318_0"
]
},
{
"input_order": [
0
],
"inputs": [
"Negative_2318"
],
"name": "Reshape_2319",
"op": "Reshape",
"op_version": 0,
"output_shape": [
41,
1
],
"outputs": [
"Reshape_2319_0"
]
},
{
"inputs": [
"Reshape_2319"
],
"name": "Result_2320",
"needs_default_layout": true,
"op": "Result",
"op_version": 0,
"outputs": [
"Result_2320_0"
]
}
],
"parameters": [
"Parameter_2309",
"Parameter_2310"
],
"result": [
"Result_2320"
]
}
]