Commit a1a8a7e3 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

Implementation of CrossEntropy and CrossEntropyBackprop as fused Op's (#3818)

* - Implementaion of CrossEntropy and CrossEntropyBackprop as fused Op's

* - unit test case for CE fprop
- fix bug in decompose_op

* WIP debug PDPD unit test failure

* fixed broadcasting issue

* -fix bdcast issue for multi dim tensor

* utilities to restore the original tensor shape

* i) style-fix ii) rename variables

* - unit test for multiple dimensions ii) refactor create_mask to seperate function

* - fixed unit tests

* fix style

* set output element type to dynamic in pre_validate and infer shape

* disable ce with one hot unit test on PlaidML

* add CE op to fused_op_tbl

* - add serialzier support for CE and CE Backprop
parent 19e2434a
...@@ -337,6 +337,8 @@ set (SRC ...@@ -337,6 +337,8 @@ set (SRC
op/fused/clamp.hpp op/fused/clamp.hpp
op/fused/conv_fused.cpp op/fused/conv_fused.cpp
op/fused/conv_fused.hpp op/fused/conv_fused.hpp
op/fused/crossentropy.cpp
op/fused/crossentropy.hpp
op/fused/hard_sigmoid.cpp op/fused/hard_sigmoid.cpp
op/fused/hard_sigmoid.hpp op/fused/hard_sigmoid.hpp
op/fused/depth_to_space.cpp op/fused/depth_to_space.cpp
......
...@@ -133,6 +133,7 @@ namespace ngraph ...@@ -133,6 +133,7 @@ namespace ngraph
#include "ngraph/op/floor_mod.hpp" #include "ngraph/op/floor_mod.hpp"
#include "ngraph/op/fused/clamp.hpp" #include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp" #include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/crossentropy.hpp"
#include "ngraph/op/fused/depth_to_space.hpp" #include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp" #include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp" #include "ngraph/op/fused/fake_quantize.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/crossentropy.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::CrossEntropy::type_info;
op::CrossEntropy::CrossEntropy(const Output<Node>& arg1,
const Output<Node>& arg2,
bool soft_label,
int64_t ignore_index)
: FusedOp({arg1, arg2})
, m_soft_label(soft_label)
, m_ignore_index(ignore_index)
{
constructor_validate_and_infer_types();
}
static AxisVector get_axis_vector(size_t rank)
{
AxisVector axis_vector;
for (size_t i = 0; i < rank; i++)
{
axis_vector.push_back(i);
}
return axis_vector;
}
static Shape get_result_shape(Shape& target_shape, int start, int end)
{
Shape result;
for (size_t i = start; i < end; i++)
{
result.push_back(target_shape[i]);
}
return result;
}
static Output<Node> get_2d_tensor(Output<Node> node)
{
if (node.get_shape().size() == 2)
{
return node;
}
Shape node_shape = node.get_shape();
size_t rank = node_shape.size();
Shape result_shape{(shape_size(node_shape) / node_shape[rank - 1]), node_shape[rank - 1]};
auto reshape = std::make_shared<ngraph::op::Reshape>(node, get_axis_vector(rank), result_shape);
return reshape;
}
static std::shared_ptr<Node> expand_shape(std::shared_ptr<Node> result, Output<Node> original)
{
Shape result_shape = result->get_shape();
Shape original_shape = original.get_shape();
if (result_shape == original_shape && result_shape.size() == 2)
{
return result;
}
size_t original_rank = original_shape.size();
size_t result_rank = result_shape.size();
// expand the first dimension of the computed result to match the original tensor shape
Shape new_shape = get_result_shape(original_shape, 0, original_rank - 1);
// restore the last dimension of computed result
new_shape.push_back(result_shape[result_rank - 1]);
if (new_shape.size() != original_shape.size())
{
throw ngraph_error(
"CrossEntropy shape size mismatch in restoring the original tensor shape");
}
auto reshape = std::make_shared<ngraph::op::Reshape>(result, AxisVector{0, 1}, new_shape);
return reshape;
}
// create mask based on ignore_index
static std::shared_ptr<ngraph::Node>
create_mask(Output<Node> labels, Output<Node> input, int64_t ignore_index)
{
auto mask_constant =
ngraph::op::Constant::create(labels.get_element_type(), labels.get_shape(), {ignore_index});
auto not_equal = std::make_shared<ngraph::op::NotEqual>(labels, mask_constant);
auto convert = std::make_shared<ngraph::op::Convert>(not_equal, input.get_element_type());
return convert;
}
NodeVector op::CrossEntropy::decompose_op() const
{
// we will reshape the labels and input tensor to 2d
auto input_to_normalize = get_2d_tensor(input_value(0));
auto labels = get_2d_tensor(input_value(1));
auto reduction_axis = input_to_normalize.get_shape().size() - 1;
auto create_xe = [&](const Output<Node>& one_hot, const Output<Node>& input) {
auto node_log = std::make_shared<ngraph::op::Log>(input);
auto node_mul = one_hot * node_log;
auto node_sum = std::make_shared<ngraph::op::Sum>(
node_mul, AxisSet{static_cast<size_t>(reduction_axis)});
return -node_sum;
};
// mask
std::shared_ptr<ngraph::Node> mask = create_mask(labels, input_to_normalize, m_ignore_index);
if (m_soft_label)
{
// insert dtype conversion if required
if (labels.get_element_type() != input_to_normalize.get_element_type())
{
labels = std::make_shared<ngraph::op::Convert>(labels,
input_to_normalize.get_element_type());
}
if (labels.get_shape()[reduction_axis] == 1)
{
auto reshape_labels = std::make_shared<ngraph::op::Reshape>(
labels, AxisVector{0, 1}, Shape{labels.get_shape().at(0)});
labels = std::make_shared<ngraph::op::Broadcast>(
reshape_labels,
input_to_normalize.get_shape(),
AxisSet{input_to_normalize.get_shape().size() - 1});
}
auto xe = create_xe(labels, input_to_normalize);
auto reshape_xe = std::make_shared<ngraph::op::Reshape>(
xe, AxisVector{0}, Shape{xe->get_shape().at(0), 1});
return {expand_shape(reshape_xe, input_value(0))};
}
else
{
// we will have one_hot encoding on labels if softmax_labels = false
size_t one_hot_axis = input_to_normalize.get_shape().size() - 1;
auto reshape_labels =
make_shared<op::Reshape>(labels, AxisVector{0, 1}, Shape{labels.get_shape().at(0)});
auto one_hot_labels = std::make_shared<ngraph::op::OneHot>(
reshape_labels, input_to_normalize.get_shape(), one_hot_axis);
auto convert_one_hot = std::make_shared<ngraph::op::Convert>(
one_hot_labels, input_to_normalize.get_element_type());
// calculate loss
auto xe = create_xe(convert_one_hot, input_to_normalize);
auto reshape_xe = std::make_shared<ngraph::op::Reshape>(
xe, AxisVector{0}, Shape{xe->get_shape().at(0), 1});
if (m_ignore_index > 0)
{
return {reshape_xe * mask};
}
return {expand_shape(reshape_xe, input_value(0))};
}
}
shared_ptr<Node> op::CrossEntropy::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<CrossEntropy>(new_args.at(0), new_args.at(1), m_soft_label, m_ignore_index);
}
void op::CrossEntropy::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
if (is_dynamic())
{
return;
}
}
constexpr NodeTypeInfo op::CrossEntropyBackprop::type_info;
op::CrossEntropyBackprop::CrossEntropyBackprop(const Output<Node>& input,
const Output<Node>& labels,
const Output<Node>& delta,
bool soft_label,
int64_t ignore_index)
: FusedOp({input, labels, delta})
, m_soft_label(soft_label)
, m_ignore_index(ignore_index)
{
constructor_validate_and_infer_types();
}
void op::CrossEntropyBackprop::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
shared_ptr<Node> op::CrossEntropyBackprop::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<CrossEntropyBackprop>(
new_args.at(0), new_args.at(1), new_args.at(2), m_soft_label, m_ignore_index);
}
NodeVector op::CrossEntropyBackprop::decompose_op() const
{
auto input = get_2d_tensor(input_value(0));
auto labels = get_2d_tensor(input_value(1));
auto delta = get_2d_tensor(input_value(2));
auto rank = input.get_shape().size();
size_t one_hot_axis = delta.get_shape().size() - 1;
// always reduces the sum on the last axis
auto reduction_axis = delta.get_shape().size() - 1;
// mask
std::shared_ptr<ngraph::Node> mask = nullptr;
// remove trailing ones from delta
auto delta_reshape = std::make_shared<ngraph::op::Reshape>(
delta, AxisVector{0, 1}, Shape{delta.get_shape().at(0)});
auto delta_bcast = std::make_shared<ngraph::op::Broadcast>(
delta_reshape, input.get_shape(), AxisSet{rank - 1});
if (!m_soft_label)
{
// ignore mask
if (m_ignore_index > 0)
{
mask = create_mask(labels, input, m_ignore_index);
mask = std::make_shared<ngraph::op::Reshape>(
mask, AxisVector{0, 1}, Shape{mask->get_shape().at(0)});
mask =
std::make_shared<ngraph::op::Broadcast>(mask, input.get_shape(), AxisSet{rank - 1});
}
if (labels.get_shape()[reduction_axis] == 1)
{
labels =
make_shared<op::Reshape>(labels, AxisVector{0, 1}, Shape{labels.get_shape().at(0)});
}
// one hot encoding of labels
auto one_hot =
std::make_shared<ngraph::op::OneHot>(labels, input.get_shape(), one_hot_axis);
labels = std::make_shared<ngraph::op::Convert>(one_hot, input.get_element_type());
}
std::shared_ptr<ngraph::Node> xe_grad =
std::make_shared<ngraph::op::Divide>(-labels * delta_bcast, input);
if (!m_soft_label && m_ignore_index > 0)
{
xe_grad = xe_grad * mask;
}
return {expand_shape(xe_grad, input_value(0))};
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
class CrossEntropy : public ngraph::op::util::FusedOp
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"CrossEntropy", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
CrossEntropy() = default;
/// \brief CrossEntropy for computing loss
/// \param arg1 Node that produces the input tensor
/// \param arg2 Node that produces ground truth lables for the input
/// \param soft_label flag indicating whether to interpretate the given labels as soft
/// labels
/// \param ignore_index Specifies a target value that is ignored and does not contribute
/// to the input gradient Only valid if soft_label is set to False
CrossEntropy(const Output<Node>& arg1,
const Output<Node>& arg2,
bool soft_label = false,
int64_t ignore_index = -100);
virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_soft_label() const { return m_soft_label; }
int64_t get_ignore_index() const { return m_ignore_index; }
private:
bool m_soft_label;
int64_t m_ignore_index;
};
class CrossEntropyBackprop : public util::FusedOp
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"CrossEntropyBackprop", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
CrossEntropyBackprop() = default;
/// \brief Backprop for CrossEntropy
/// \param input Node that produces tensor from the fprop
/// \param labels Node that produces ground truth labels for input
/// \param delta Node that produces the delta during bprop
/// \param soft_label flag indicating whether to interpretate the given labels as soft
/// labels
/// \param ignore_index Specifies a target value that is ignored and does not contribute
/// to the input gradient Only valid if soft_label is set to False
CrossEntropyBackprop(const Output<Node>& input,
const Output<Node>& labels,
const Output<Node>& delta,
bool soft_label = false,
int64_t ignore_index = -100);
virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_soft_label() const { return m_soft_label; }
int64_t get_ignore_index() const { return m_ignore_index; }
private:
bool m_soft_label;
int64_t m_ignore_index;
};
} // namespace op
} // namespace ngraph
...@@ -26,6 +26,8 @@ NGRAPH_OP(Clamp, ngraph::op) ...@@ -26,6 +26,8 @@ NGRAPH_OP(Clamp, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op) NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op) NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op) NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(CrossEntropy, ngraph::op)
NGRAPH_OP(CrossEntropyBackprop, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op) NGRAPH_OP(DepthToSpace, ngraph::op)
NGRAPH_OP(Elu, ngraph::op) NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op) NGRAPH_OP(FakeQuantize, ngraph::op)
......
...@@ -147,6 +147,7 @@ pad_reflect_1d_multi_reflect ...@@ -147,6 +147,7 @@ pad_reflect_1d_multi_reflect
pad_reflect_2d pad_reflect_2d
pad_reflect_2d_with_neg pad_reflect_2d_with_neg
pad_symmetric pad_symmetric
cross_entropy_with_one_hot
# No double precision FP support in PlaidML # No double precision FP support in PlaidML
sum_trivial_in_double sum_trivial_in_double
......
...@@ -73,6 +73,7 @@ ...@@ -73,6 +73,7 @@
#include "ngraph/op/floor_mod.hpp" #include "ngraph/op/floor_mod.hpp"
#include "ngraph/op/fused/clamp.hpp" #include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp" #include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/crossentropy.hpp"
#include "ngraph/op/fused/depth_to_space.hpp" #include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp" #include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp" #include "ngraph/op/fused/fake_quantize.hpp"
...@@ -1390,6 +1391,21 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1390,6 +1391,21 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Cosh>(args[0]); node = make_shared<op::Cosh>(args[0]);
break; break;
} }
case OP_TYPEID::CrossEntropy:
{
auto soft_label = node_js.at("soft_label");
auto ignore_index = node_js.at("ignore_index");
node = make_shared<op::CrossEntropy>(args[0], args[1], soft_label, ignore_index);
break;
}
case OP_TYPEID::CrossEntropyBackprop:
{
auto soft_label = node_js.at("soft_label");
auto ignore_index = node_js.at("ignore_index");
node = make_shared<op::CrossEntropyBackprop>(
args[0], args[1], args[2], soft_label, ignore_index);
break;
}
case OP_TYPEID::DepthToSpace: case OP_TYPEID::DepthToSpace:
{ {
auto mode = node_js.at("mode").get<op::DepthToSpace::DepthToSpaceMode>(); auto mode = node_js.at("mode").get<op::DepthToSpace::DepthToSpaceMode>();
...@@ -3279,6 +3295,20 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -3279,6 +3295,20 @@ json JSONSerializer::serialize_node(const Node& n)
} }
case OP_TYPEID::Cosh: { break; case OP_TYPEID::Cosh: { break;
} }
case OP_TYPEID::CrossEntropy:
{
auto tmp = static_cast<const op::CrossEntropy*>(&n);
node["soft_label"] = tmp->get_soft_label();
node["ignore_index"] = tmp->get_ignore_index();
break;
}
case OP_TYPEID::CrossEntropyBackprop:
{
auto tmp = static_cast<const op::CrossEntropyBackprop*>(&n);
node["soft_label"] = tmp->get_soft_label();
node["ignore_index"] = tmp->get_ignore_index();
break;
}
case OP_TYPEID::Dequantize: case OP_TYPEID::Dequantize:
{ {
auto tmp = static_cast<const op::Dequantize*>(&n); auto tmp = static_cast<const op::Dequantize*>(&n);
......
...@@ -2548,3 +2548,47 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_activation_function) ...@@ -2548,3 +2548,47 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_activation_function)
test_case.run(); test_case.run();
} }
NGRAPH_TEST(${BACKEND_NAME}, cross_entropy_with_soft_labels)
{
Shape tensor_shape{2, 4};
auto input = make_shared<op::Parameter>(element::f32, tensor_shape);
auto labels = make_shared<op::Parameter>(element::i32, Shape{2, 4});
auto cross_entropy = make_shared<op::CrossEntropy>(input, labels, true);
auto f0 = make_shared<Function>(NodeVector{cross_entropy}, ParameterVector{input, labels});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, tensor_shape);
copy_data(a, vector<float>{0.25f, 0.25f, 0.25f, 0.25f, 0.01f, 0.01f, 0.01f, 0.96f});
auto b = backend->create_tensor(element::i32, Shape{2, 4});
copy_data(b, vector<int32_t>{0, 0, 0, 1, 0, 0, 0, 1});
auto result0 = backend->create_tensor(element::f32, Shape{2, 1});
auto handle = backend->compile(f0);
handle->call_with_validate({result0}, {a, b});
vector<float> expected{1.38629f, 0.040822f};
auto result = read_vector<float>(result0);
EXPECT_TRUE(test::all_close_f(result, expected, 23));
}
NGRAPH_TEST(${BACKEND_NAME}, cross_entropy_with_one_hot)
{
Shape tensor_shape{2, 4};
auto input = make_shared<op::Parameter>(element::f32, tensor_shape);
auto labels = make_shared<op::Parameter>(element::i32, Shape{2, 1});
auto cross_entropy = make_shared<op::CrossEntropy>(input, labels, false);
auto f0 = make_shared<Function>(NodeVector{cross_entropy}, ParameterVector{input, labels});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, tensor_shape);
copy_data(a, vector<float>{0.25f, 0.25f, 0.25f, 0.25f, 0.01f, 0.01f, 0.01f, 0.96f});
auto b = backend->create_tensor(element::i32, Shape{2, 1});
copy_data(b, vector<int32_t>{1, 1});
auto result0 = backend->create_tensor(element::f32, Shape{2, 1});
auto handle = backend->compile(f0);
handle->call_with_validate({result0}, {a, b});
vector<float> expected{1.38629f, 4.60517f};
auto result = read_vector<float>(result0);
EXPECT_TRUE(test::all_close_f(result, expected, 23));
}
...@@ -853,3 +853,42 @@ TEST(core_fusion, softmax_crossentropy) ...@@ -853,3 +853,42 @@ TEST(core_fusion, softmax_crossentropy)
test_softmax_crossentropy(Shape{41, 37}, Shape{41, 37}, true, -1); test_softmax_crossentropy(Shape{41, 37}, Shape{41, 37}, true, -1);
test_softmax_crossentropy(Shape{41, 37}, Shape{41, 1}, false, 5); test_softmax_crossentropy(Shape{41, 37}, Shape{41, 1}, false, 5);
} }
void test_crossentropy(Shape input_shape, Shape label_shape, bool soft_label, int64_t ignore_index)
{
auto input = std::make_shared<op::Parameter>(element::f64, input_shape);
auto labels = std::make_shared<op::Parameter>(element::i64, label_shape);
auto sm_ce = std::make_shared<op::CrossEntropy>(input, labels, soft_label, ignore_index);
auto cpu_f = make_shared<Function>(sm_ce, ParameterVector{input, labels});
test::Uniform<double> rng(-1.0, 1.0);
vector<vector<double>> args;
for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
{
vector<double> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto cpu_results = execute(cpu_f, args, "CPU");
// if softlabels = flase, we will have one one hot encoding for labels
if (!soft_label)
{
size_t onehot = count_ops_of_type<op::OneHot>(cpu_f);
ASSERT_EQ(onehot, 1);
}
if (ignore_index >= 0 && !soft_label)
// check for the mask
{
size_t not_equal = count_ops_of_type<op::NotEqual>(cpu_f);
ASSERT_EQ(not_equal, 1);
}
}
TEST(core_fusion, crossentropy)
{
test_crossentropy(Shape{41, 37}, Shape{41, 37}, true, -1);
test_crossentropy(Shape{41, 37}, Shape{41, 1}, false, 5);
test_crossentropy(Shape{10, 2, 4, 10}, Shape{10, 2, 4, 1}, false, 5);
test_crossentropy(Shape{4, 3, 2, 4}, Shape{4, 3, 2, 4}, true, -1);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment