Unverified Commit e4b90a9c authored by Nick Korovaiko, committed by GitHub

Batchnorm Bprop v2 (#567)

* one output

multiple outputs

initial clean-up

* test clean-up

current version

test pass

* clean up

* fix format

* add dbeta, dgamma asserts

* revert some files

* 0644 on node.cpp

* 0644 on mkldnn_utils.cpp

* 0644 on more files

* add support for serialization + test case

* fix merge errors
parent 355bff8f
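
This change adds a multi-output op::BatchNormBackprop op (inputs: gamma, beta, input, mean, variance, delta; outputs: dinput, dgamma, dbeta), wires it into BatchNorm::generate_adjoints via op::GetOutputElement, lowers it to MKL-DNN's batch_normalization_backward in the CPU emitter, and teaches the serializer about both BatchNormBackprop and GetOutputElement. Below is a minimal sketch of how the new op is consumed, mirroring generate_adjoints and the new test in the diff; the "ngraph/ngraph.hpp" umbrella include and the helper name make_bn_bprop_function are illustrative assumptions, not part of this commit.

#include <memory>
#include "ngraph/ngraph.hpp" // assumption: umbrella header exposing op::Parameter, op::BatchNormBackprop, op::GetOutputElement

using namespace ngraph;

std::shared_ptr<Function> make_bn_bprop_function()
{
    // 4D (N, C, H, W) input with per-channel gamma/beta/mean/variance, as the new op's shape checks require
    Shape input_shape{4, 3, 2, 2};
    Shape channel_shape{3};
    auto input = std::make_shared<op::Parameter>(element::f32, input_shape);
    auto gamma = std::make_shared<op::Parameter>(element::f32, channel_shape);
    auto beta = std::make_shared<op::Parameter>(element::f32, channel_shape);
    auto mean = std::make_shared<op::Parameter>(element::f32, channel_shape);
    auto var = std::make_shared<op::Parameter>(element::f32, channel_shape);
    auto delta = std::make_shared<op::Parameter>(element::f32, input_shape);

    double eps = 0.001;
    // One node, three outputs (dinput, dgamma, dbeta); the same pattern BatchNorm::generate_adjoints uses
    auto bbn = std::make_shared<op::BatchNormBackprop>(eps, gamma, beta, input, mean, var, delta);
    auto dinput = std::make_shared<op::GetOutputElement>(bbn, 0);
    auto dgamma = std::make_shared<op::GetOutputElement>(bbn, 1);
    auto dbeta = std::make_shared<op::GetOutputElement>(bbn, 2);

    return std::make_shared<Function>(NodeVector{dinput, dgamma, dbeta},
                                      op::ParameterVector{gamma, beta, input, mean, var, delta});
}

The serializer records the op's epsilon under the "eps" key and the GetOutputElement index under "n", which is what the serialization round trip in the new bn_bprop_n4c3h2w2 test exercises.
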
@@ -16,6 +16,7 @@
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/get_output_element.hpp"
ngraph::op::BatchNorm::BatchNorm(double eps,
std::shared_ptr<ngraph::Node> gamma,
@@ -94,3 +95,94 @@ std::shared_ptr<ngraph::Node>
return std::make_shared<BatchNorm>(
m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
}
ngraph::op::BatchNormBackprop::BatchNormBackprop(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
std::shared_ptr<ngraph::Node> delta)
: RequiresTensorViewArgs("BatchNormBackprop", {gamma, beta, input, mean, variance, delta})
, epsilon(eps)
{
if (input->get_shape().size() != 4)
{
throw ngraph_error("Input expected to be a 4D tensor");
}
auto et = input->get_element_type();
const char* input_names[] = {"gamma", "beta", "input", "mean", "variance", "delta"};
for (size_t i = 0; i < get_input_size(); i++)
{
if (get_input_op(i)->get_element_type() != et)
{
auto err_msg = std::string("The element type of ") + input_names[i] +
" isn't equal to input data's type";
throw ngraph_error(err_msg.c_str());
}
}
Shape channel_shape{input->get_shape().at(1)};
for (size_t i = 0; i < get_input_size(); i++)
{
if (i == 2 || i == 5) //don't check input and delta
{
continue;
}
if (get_input_op(i)->get_shape() != channel_shape)
{
auto err_msg = std::string("The shape of ") + input_names[i] +
" isn't equal to input channel's shape";
throw ngraph_error(err_msg.c_str());
}
}
if (delta->get_shape() != input->get_shape())
{
throw ngraph_error("delta shape is expected to be equal to input shape");
}
add_output(input->get_element_type(), input->get_shape());
add_output(gamma->get_element_type(), gamma->get_shape());
add_output(beta->get_element_type(), beta->get_shape());
}
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormBackprop::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 6)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<op::BatchNormBackprop>(epsilon,
new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5));
}
void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto gamma = get_input_op(0);
auto beta = get_input_op(1);
auto input = get_input_op(2);
auto mean = get_input_op(3);
auto variance = get_input_op(4);
auto bbn = std::make_shared<op::BatchNormBackprop>(
get_eps_value(), gamma, beta, input, mean, variance, delta);
auto dinput = std::make_shared<op::GetOutputElement>(bbn, 0);
auto dgamma = std::make_shared<op::GetOutputElement>(bbn, 1);
auto dbeta = std::make_shared<op::GetOutputElement>(bbn, 2);
adjoints.add_delta(input, dinput);
adjoints.add_delta(gamma, dgamma);
adjoints.add_delta(beta, dbeta);
}
@@ -44,11 +44,34 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
private:
Shape m_bn_input_shape;
Shape m_bn_variance_shape;
Shape m_bn_mean_shape;
double m_epsilon;
};
class BatchNormBackprop : public util::RequiresTensorViewArgs
{
public:
BatchNormBackprop(double eps,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input,
std::shared_ptr<Node> mean,
std::shared_ptr<Node> variance,
std::shared_ptr<Node> delta);
double get_eps_value() const { return epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
double epsilon;
};
}
}
@@ -62,6 +62,11 @@ namespace ngraph
/// \return The index of the tuple element to get.
size_t get_n() const { return m_n; }
virtual NodeVector get_input_ops() override
{
return NodeVector{get_inputs().at(0).get_output().get_node()};
}
protected:
size_t m_n;
};
...
@@ -17,7 +17,6 @@
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <string>
#include <typeindex>
@@ -370,6 +369,117 @@ namespace ngraph
writer << "}\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormBackprop)
{
const ngraph::op::BatchNormBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormBackprop*>(node);
auto gamma_shape = args[0].get_shape();
auto beta_shape = args[1].get_shape();
auto input_shape = args[2].get_shape();
auto mean_shape = args[3].get_shape();
auto variance_shape = args[4].get_shape();
auto delta_shape = args[5].get_shape();
auto result_shape = out[0].get_shape();
// get input element type
const string& et =
mkldnn_utils::get_mkldnn_data_type_string(args[2].get_element_type());
writer << "{\n";
writer.indent++;
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(" << input_shape[1] * 2 << ");\n";
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">vdiff_weights(" << input_shape[1] * 2 << ");\n";
auto weights_shape = Shape{2, input_shape[1]};
// push gamma and beta
writer << "auto gamma = " << args[0].get_name() << ";\n";
writer << "auto beta = " << args[1].get_name() << ";\n";
writer << "memcpy(&bn_weights[0], gamma,"
<< args[1].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[1].get_size() << ", beta, "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
// get the eps value from the bn node
writer << "auto epsilon = " << batchnorm->get_eps_value() << ";\n";
// Bind to CPU engine
writer << "using namespace mkldnn; \n";
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
// create memory descriptors
writer << "memory::desc input_data_desc = memory::desc({" << join(input_shape)
<< "}, " << et << ", memory::format::nchw);\n";
// TODO define weights by stacking gamma and beta values
writer << "memory::desc weights_desc = memory::desc({" << join(weights_shape)
<< "}, " << et << ", memory::format::nc);\n";
writer << "memory::desc diff_weights_desc = memory::desc({" << join(weights_shape)
<< "}, " << et << ", memory::format::nc);\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape) << "}, "
<< et << ", memory::format::nchw);\n";
writer << "memory::desc mean_desc = memory::desc({" << join(mean_shape) << "}, "
<< et << ", memory::format::x);\n";
writer << "memory::desc variance_desc = memory::desc({" << join(variance_shape)
<< "}, " << et << ", memory::format::x);\n";
writer << "memory::desc delta_desc = memory::desc({" << join(input_shape) << "}, "
<< et << ", memory::format::nchw);\n";
// Define memory for the user data
writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
<< args[2].get_name() << ");\n";
writer << "memory weights = memory({weights_desc, cpu_engine}, bn_weights.data()"
<< ");\n";
writer << "memory diff_weights = memory({diff_weights_desc, cpu_engine}, "
"vdiff_weights.data()"
<< ");\n";
writer << "memory mean = memory({mean_desc, cpu_engine}, " << args[3].get_name()
<< ");\n";
writer << "memory variance = memory({variance_desc, cpu_engine}, "
<< args[4].get_name() << ");\n";
writer << "memory delta = memory({delta_desc, cpu_engine}, " << args[5].get_name()
<< ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, " << out[0].get_name()
<< ");\n";
//create fprop batchnorm descriptor
writer << "batch_normalization_forward::desc bn_fprop_desc = "
"batch_normalization_forward::desc(forward_training,"
<< "input_data_desc, epsilon, use_scale_shift);\n";
//bn fprop primitive descriptor
writer
<< "batch_normalization_forward::primitive_desc bn_fprop_prim_desc = "
"batch_normalization_forward::primitive_desc(bn_fprop_desc, cpu_engine);\n";
//create bprop batchnorm descriptor
writer << "batch_normalization_backward::desc bn_bprop_desc = "
"batch_normalization_backward::desc(backward, delta_desc, "
"input_data_desc, epsilon, use_scale_shift);\n";
//bn bprop primitive descriptor
writer << "batch_normalization_backward::primitive_desc bn_bprop_prim_desc = "
"batch_normalization_backward::primitive_desc(bn_bprop_desc, cpu_engine, "
"bn_fprop_prim_desc);\n";
//create a batchnorm bprop primitive
writer << " batch_normalization_backward bn_bprop = "
"batch_normalization_backward(bn_bprop_prim_desc, input_data, mean, "
"variance, delta, weights, result, diff_weights);\n ";
//create stream and execute
writer << "stream s = stream(stream::kind::eager);\n"
<< "s.submit({bn_bprop}).wait();\n";
writer << "memcpy(" << out[1].get_name() << ",&vdiff_weights[0],"
<< args[1].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(" << out[2].get_name() << ",&vdiff_weights[0] + "
<< args[1].get_size() << ","
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Dot)
{
...
@@ -179,6 +179,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Concat), &runtime::cpu::CPU_Emitter::emit<op::Concat>},
{TI(ngraph::op::Divide), &runtime::cpu::CPU_Emitter::emit<op::Divide>},
{TI(ngraph::op::Equal), &runtime::cpu::CPU_Emitter::emit<op::Equal>},
{TI(ngraph::op::GetOutputElement), &runtime::cpu::CPU_Emitter::emit<op::GetOutputElement>},
{TI(ngraph::op::Greater), &runtime::cpu::CPU_Emitter::emit<op::Greater>},
{TI(ngraph::op::GreaterEq), &runtime::cpu::CPU_Emitter::emit<op::GreaterEq>},
{TI(ngraph::op::Less), &runtime::cpu::CPU_Emitter::emit<op::Less>},
@@ -231,6 +232,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::AvgPoolBackprop), &runtime::cpu::CPU_Emitter::emit<op::AvgPoolBackprop>},
{TI(ngraph::op::Pad), &runtime::cpu::CPU_Emitter::emit<op::Pad>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::CPU_Emitter::emit<op::BatchNorm>},
{TI(ngraph::op::BatchNormBackprop), &runtime::cpu::CPU_Emitter::emit<op::BatchNormBackprop>},
{TI(ngraph::op::MaxPoolBackprop), &runtime::cpu::CPU_Emitter::emit<op::MaxPoolBackprop>},
{TI(ngraph::op::Product), &runtime::cpu::CPU_Emitter::emit<op::Product>},
{TI(ngraph::op::Max), &runtime::cpu::CPU_Emitter::emit<op::Max>},
...
@@ -43,6 +43,7 @@ static const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::AvgPool),
TI(ngraph::op::AvgPoolBackprop),
TI(ngraph::op::BatchNorm),
TI(ngraph::op::BatchNormBackprop),
TI(ngraph::op::Convolution),
TI(ngraph::op::ConvolutionBackpropData),
TI(ngraph::op::ConvolutionBackpropFilters),
...
@@ -75,7 +75,7 @@ shared_ptr<Node> runtime::cpu::pass::CPULayout::insert_input_conversions(
}
else
{
- new_args.push_back(node->get_input_op(index));
+ new_args.push_back(output.get_node());
}
index++;
}
@@ -163,7 +163,7 @@ void runtime::cpu::pass::CPULayout::set_default_layouts(
}
else
{
- new_args.push_back(node->get_input_op(index));
+ new_args.push_back(output.get_node());
}
index++;
}
...
@@ -327,6 +327,12 @@ static shared_ptr<ngraph::Function>
auto epsilon = node_js.at("eps").get<double>();
node = make_shared<op::BatchNorm>(epsilon, args[0], args[1], args[2], args[3], args[4]);
}
else if (node_op == "BatchNormBackprop")
{
auto epsilon = node_js.at("eps").get<double>();
node = make_shared<op::BatchNormBackprop>(
epsilon, args[0], args[1], args[2], args[3], args[4], args[5]);
}
else if (node_op == "Broadcast") else if (node_op == "Broadcast")
{ {
auto shape = node_js.at("shape").get<vector<size_t>>(); auto shape = node_js.at("shape").get<vector<size_t>>();
...@@ -482,10 +488,10 @@ static shared_ptr<ngraph::Function> ...@@ -482,10 +488,10 @@ static shared_ptr<ngraph::Function>
shared_ptr<Function> f_ptr = function_map.at(function_name); shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::FunctionCall>(f_ptr, args); node = make_shared<op::FunctionCall>(f_ptr, args);
} }
// else if (node_op == "GetOutputElement") else if (node_op == "GetOutputElement")
// { {
// node = make_shared<op::GetOutputElement>(args[0]); node = make_shared<op::GetOutputElement>(args[0], node_js.at("n").get<size_t>());
// } }
else if (node_op == "Greater") else if (node_op == "Greater")
{ {
node = make_shared<op::Greater>(args[0], args[1]); node = make_shared<op::Greater>(args[0], args[1]);
...@@ -835,6 +841,11 @@ static json write(const Node& n) ...@@ -835,6 +841,11 @@ static json write(const Node& n)
auto tmp = dynamic_cast<const op::BatchNorm*>(&n); auto tmp = dynamic_cast<const op::BatchNorm*>(&n);
node["eps"] = tmp->get_eps_value(); node["eps"] = tmp->get_eps_value();
} }
else if (node_op == "BatchNormBackprop")
{
auto tmp = dynamic_cast<const op::BatchNormBackprop*>(&n);
node["eps"] = tmp->get_eps_value();
}
else if (node_op == "Broadcast") else if (node_op == "Broadcast")
{ {
auto tmp = dynamic_cast<const op::Broadcast*>(&n); auto tmp = dynamic_cast<const op::Broadcast*>(&n);
...@@ -919,6 +930,8 @@ static json write(const Node& n) ...@@ -919,6 +930,8 @@ static json write(const Node& n)
} }
else if (node_op == "GetOutputElement") else if (node_op == "GetOutputElement")
{ {
auto tmp = dynamic_cast<const op::GetOutputElement*>(&n);
node["n"] = tmp->get_n();
}
else if (node_op == "Greater")
{
...
@@ -308,3 +308,97 @@ TEST(cpu_fusion, unhandled_op)
auto external = manager->compile(f);
ASSERT_THROW(backend->make_call_frame(external), ngraph_error);
}
TEST(cpu_fusion, bn_bprop_n4c3h2w2)
{
auto input_shape = Shape{4, 3, 2, 2};
auto shape_mean = Shape{3};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{3};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{3};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{3};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{3};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{4, 3, 2, 2};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
auto manager = runtime::Manager::get("CPU");
auto backend = manager->allocate_backend();
auto _input = backend->make_primary_tensor_view(element::f32, input_shape);
vector<float> dataInput{
10.76331902f, 11.51178265f, 10.31018162f, 12.2993021f, 14.17626667f, 14.63498497f,
13.63494492f, 13.84248161f, 11.34602547f, 13.22014618f, 10.46686649f, 10.39842987f,
12.94806862f, 11.71670246f, 14.94438076f, 13.13236618f, 13.40889645f, 12.76128387f,
11.34430027f, 11.86629677f, 11.11464024f, 10.93221283f, 11.95324039f, 10.96581173f,
13.05455494f, 14.41404247f, 13.11169434f, 11.26559448f, 10.89965153f, 14.08202171f,
11.12685776f, 12.58428574f, 12.59247875f, 13.00187492f, 12.66310215f, 10.06655025f,
12.62048626f, 14.47942352f, 13.84950638f, 10.61425877f, 11.47936344f, 13.06011772f,
13.63069057f, 12.31748772f, 13.84555244f, 10.95815468f, 12.78933334f, 12.75389099f};
copy_data(_input, dataInput);
auto _mean = backend->make_primary_tensor_view(element::f32, mean_shape);
copy_data(_mean, vector<float>{12.56472874f, 12.80312157f, 11.81676865f});
auto _var = backend->make_primary_tensor_view(element::f32, var_shape);
copy_data(_var, vector<float>{1.94557643f, 1.32772446f, 1.28163588f});
auto _gamma = backend->make_primary_tensor_view(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{2.0f, 2.0f, 2.0f});
auto _beta = backend->make_primary_tensor_view(element::f32, beta_shape);
copy_data(_beta, vector<float>{1.0f, 1.0f, 1.0f});
auto result = backend->make_primary_tensor_view(element::f32, shape_r);
shared_ptr<runtime::TensorView> _delta =
backend->make_primary_tensor_view(element::f32, shape_r);
vector<float> deltaData(shape_size(shape_r), 20.0f);
copy_data(_delta, deltaData);
auto f = make_shared<Function>(bn, op::ParameterVector{mean, var, input, gamma, beta});
auto C = std::make_shared<op::Parameter>(element::f32, shape_r);
auto dinput = bn->backprop_node(input, C);
auto dgamma = bn->backprop_node(gamma, C);
auto dbeta = bn->backprop_node(beta, C);
auto df = make_shared<Function>(NodeVector{dinput, dgamma, dbeta},
op::ParameterVector{mean, var, input, gamma, beta, C});
//roundtrip serialization
string js = serialize(df, 4);
istringstream in(js);
df = deserialize(in);
auto external = manager->compile(df);
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> _dinput =
backend->make_primary_tensor_view(element::f32, shape_r);
shared_ptr<runtime::TensorView> _dgamma =
backend->make_primary_tensor_view(element::f32, gamma_shape);
shared_ptr<runtime::TensorView> _dbeta =
backend->make_primary_tensor_view(element::f32, beta_shape);
cf->call({_mean, _var, _input, _gamma, _beta, _delta}, {_dinput, _dgamma, _dbeta});
vector<float> expected_input{
8.17051607e-06f, 4.77576657e-06f, 1.02257760e-05f, 1.20387525e-06f, -1.73868522e-06f,
3.84632768e-06f, -1.07932050e-05f, -2.57458956e-06f, -2.22166714e-06f, -8.38779043e-06f,
-2.48082982e-06f, 5.89238360e-06f, -2.52895109e-07f, -8.68433445e-06f, -5.82726737e-06f,
8.84659658e-06f, 3.03944108e-05f, 4.05480879e-05f, 1.84123158e-05f, 2.30061178e-05f,
1.34087590e-05f, -9.26072571e-07f, -3.22908454e-05f, -2.07365116e-05f, -4.21330941e-05f,
2.83083100e-05f, -3.71039101e-05f, -4.84390640e-06f, -2.93012376e-05f, 5.68858087e-06f,
1.83181458e-05f, -1.07494506e-05f, -2.32429103e-06f, 6.92914809e-06f, -6.66512321e-06f,
-7.00302840e-06f, -3.46675184e-06f, -4.36748381e-06f, 6.73822226e-07f, -4.20158993e-06f,
3.83005061e-06f, 5.85143729e-06f, 4.17875243e-06f, -8.64167783e-06f, 1.00170803e-05f,
-4.23939666e-06f, 4.80201680e-06f, 4.62702078e-06f};
ASSERT_TRUE(ngraph::test::all_close(read_vector<float>(_dinput), expected_input, 1e-3f, 1e-4f));
vector<float> expected_dgamma{7.06315041e-05f, -2.35289335e-04f, -5.06639481e-05f};
ASSERT_TRUE(
ngraph::test::all_close(read_vector<float>(_dgamma), expected_dgamma, 1e-2f, 1e-3f));
vector<float> expected_dbeta{320.f, 320.f, 320.f};
ASSERT_TRUE(ngraph::test::all_close(read_vector<float>(_dbeta), expected_dbeta, 1e-4f, 1e-8f));
}