Commit 2cbe42c8 authored by baojun's avatar baojun Committed by Sang Ik Lee

fix type mismatch (#4077)

parent 38c3be4a
......@@ -231,7 +231,7 @@ NodeVector op::SoftmaxCrossEntropyBackprop::decompose_op() const
auto mask_constant =
ngraph::op::Constant::create(element::i64, labels.get_shape(), {m_ignore_index});
auto not_equal = std::make_shared<ngraph::op::NotEqual>(labels, mask_constant);
auto convert = std::make_shared<ngraph::op::Convert>(not_equal, element::f64);
auto convert = std::make_shared<ngraph::op::Convert>(not_equal, delta.get_element_type());
auto reshape = std::make_shared<ngraph::op::Reshape>(
convert, AxisVector{0, 1}, Shape{convert->get_shape().at(0)});
auto broadcast_mask =
......@@ -242,7 +242,8 @@ NodeVector op::SoftmaxCrossEntropyBackprop::decompose_op() const
make_shared<op::Reshape>(labels, AxisVector{0, 1}, Shape{labels.get_shape().at(0)});
auto one_hot =
std::make_shared<ngraph::op::OneHot>(reshape_labels, softmax.get_shape(), one_hot_axis);
auto convert_one_hot = std::make_shared<ngraph::op::Convert>(one_hot, element::f64);
auto convert_one_hot =
std::make_shared<ngraph::op::Convert>(one_hot, delta.get_element_type());
if (delta.get_shape() != convert_one_hot->get_shape())
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment