Commit c6a0fae3 authored by varun-intel, committed by Scott Cyphers

Interpreter implementation of batch norm bprop (#934)

* updated

* type prop

* disable test in manifest

* try to exclude

* style

* double

* dobule

* more

* style

* more

* vecs

* fix goe
parent 2a64baca
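
For reference, the gradients computed by the new reference::batch_norm_backprop kernel further down follow the standard batch-norm backprop formulas. Writing \hat{x} for the normalized input, \delta for the incoming delta, c for the channel, and m for the number of elements reduced per channel (this notation is ours, not from the commit):

d\beta_c = \sum_{i \in c} \delta_i
d\gamma_c = \sum_{i \in c} \delta_i \, \hat{x}_i
dx_i = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}} \left( \delta_i - \frac{\hat{x}_i \, d\gamma_c + d\beta_c}{m} \right)

where \hat{x}_i = (x_i - \mu_c) / \sqrt{\sigma_c^2 + \epsilon}.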
@@ -14,6 +14,8 @@
* limitations under the License.
*******************************************************************************/
#include <set>
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
@@ -267,7 +269,7 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
//Please see `add_output` in `BatchNorm::BatchNorm` for more details
if (this->get_training_flag() && get_input_size() == 3)
{
auto goes = op::get_output_elements(*this);
auto goes = op::get_output_elements(shared_from_this());
mean = goes.at(1);
var = goes.at(2);
}
......
@@ -54,14 +54,14 @@ void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const
adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta, get_n());
}
NodeVector op::get_output_elements(const Node& mon)
NodeVector op::get_output_elements(const shared_ptr<Node>& mon)
{
NodeVector goes(mon.get_outputs().size());
NodeVector goes;
for (auto goe_input : mon.get_output_inputs(0))
for (size_t i = 0; i < mon->get_outputs().size(); i++)
{
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
goes.at(goe->get_n()) = goe_input->get_node();
auto goe = make_shared<GetOutputElement>(mon, i);
goes.push_back(std::static_pointer_cast<Node>(goe));
}
return goes;
}
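
The rewritten op::get_output_elements takes a shared_ptr and constructs a fresh GetOutputElement for each of the node's outputs, rather than collecting pre-existing GOE users, so it also works for nodes whose outputs have no GOE consumers yet. A hypothetical usage sketch (variable names invented for illustration), mirroring the BatchNorm::generate_adjoints change above:

// Fetch the mean and variance outputs of a three-output BatchNorm node.
auto bn = std::make_shared<op::BatchNorm>(eps, gamma, beta, input);
auto goes = op::get_output_elements(bn); // one GetOutputElement per output
auto mean = goes.at(1);
auto variance = goes.at(2);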
@@ -22,7 +22,7 @@ namespace ngraph
{
namespace op
{
NodeVector get_output_elements(const Node& mon);
NodeVector get_output_elements(const std::shared_ptr<Node>& mon);
/// \brief Operation to get an output from a node.
class GetOutputElement : public Node
......
@@ -10,3 +10,4 @@ one_hot_vector_1_barely_oob
one_hot_vector_1_far_oob
one_hot_vector_1_fp
one_hot_vector_1_fp_nonint
backwards_batch_norm_three_outputs
@@ -2,6 +2,7 @@
abc_int64
batch_norm_one_output
batch_norm_three_outputs
backwards_batch_norm_three_outputs
#need to check
computation_reuse
#int64 is not supported
......
@@ -266,6 +266,22 @@ private:
args[2]->get_shape());
}
}
else if (node_op == "BatchNormBackprop")
{
ngraph::op::BatchNormBackprop* bn_bprop =
dynamic_cast<ngraph::op::BatchNormBackprop*>(&node);
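// Inputs arrive in the BatchNormBackprop constructor order: gamma, beta,
// input, mean, variance, delta; outputs are d_input, d_gamma, d_beta.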
reference::batch_norm_backprop(bn_bprop->get_eps_value(),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[2]->get_data_ptr()),
reinterpret_cast<T*>(args[3]->get_data_ptr()),
reinterpret_cast<T*>(args[4]->get_data_ptr()),
reinterpret_cast<T*>(args[5]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
reinterpret_cast<T*>(out[1]->get_data_ptr()),
reinterpret_cast<T*>(out[2]->get_data_ptr()),
args[2]->get_shape());
}
else if (node_op == "AvgPoolBackprop")
{
op::AvgPoolBackprop* apb = dynamic_cast<op::AvgPoolBackprop*>(&node);
......
@@ -18,9 +18,15 @@
#include <cmath>
#include <iostream>
#include <vector>
#include "ngraph/axis_vector.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/util.hpp"
namespace ngraph
@@ -30,14 +36,16 @@ namespace ngraph
namespace reference
{
template <typename T>
void batch_norm_three_outputs(double eps,
const T* arg0,
const T* arg1,
const T* arg2,
T* out0,
T* out1,
T* out2,
const Shape& arg2_shape)
void batch_norm_three_outputs_with_intermediates(double eps,
const T* arg0,
const T* arg1,
const T* arg2,
T* out0,
T* out1,
T* out2,
T* out3,
T* out4,
const Shape& arg2_shape)
{
auto eps_casted = static_cast<T>(eps);
auto channels = arg2_shape[1];
@@ -85,13 +93,38 @@ namespace ngraph
auto channel_beta = arg1[c];
auto input_index = arg2_transform.index(arg2_coord);
auto normalized = (arg2[input_index] - channel_mean) /
(std::sqrt(channel_var + eps_casted));
out0[input_index] = normalized * channel_gamma + channel_beta;
out3[input_index] = arg2[input_index] - channel_mean;
out4[input_index] =
out3[input_index] / (std::sqrt(channel_var + eps_casted));
out0[input_index] = out4[input_index] * channel_gamma + channel_beta;
}
}
}
template <typename T>
void batch_norm_three_outputs(double eps,
const T* arg0,
const T* arg1,
const T* arg2,
T* out0,
T* out1,
T* out2,
const Shape& arg2_shape)
{
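// Scratch buffers for the centered and normalized activations; this entry
// point computes and then discards them (the backprop kernel below reuses them).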
std::vector<T> centered(shape_size(arg2_shape));
std::vector<T> normalized(shape_size(arg2_shape));
batch_norm_three_outputs_with_intermediates(eps,
arg0,
arg1,
arg2,
out0,
out1,
out2,
centered.data(),
normalized.data(),
arg2_shape);
}
template <typename T>
void batch_norm_one_output(double eps,
const T* arg0,
@@ -119,6 +152,92 @@ namespace ngraph
out0[input_index] = normalized * channel_gamma + channel_beta;
}
}
template <typename T>
void batch_norm_backprop(double eps,
const T* arg0,
const T* arg1,
const T* arg2,
const T* arg3,
const T* arg4,
const T* arg5,
T* out0,
T* out1,
T* out2,
const Shape& arg2_shape)
{
auto eps_casted = static_cast<T>(eps);
Shape mean_shape{arg2_shape[1]};
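// Reduce over every axis except the channel axis (axis 1).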
AxisSet reduction_axes;
for (size_t idx = 0; idx < arg2_shape.size(); idx++)
{
if (idx != 1)
{
reduction_axes.insert(idx);
}
}
auto arg2_num_elements = shape_size(arg2_shape);
auto mean_num_elements = shape_size(mean_shape);
auto reduction_axes_size = arg2_num_elements / mean_num_elements;
// Compute the mean, variance, and normalized values
std::vector<T> bn_output(arg2_num_elements);
std::vector<T> centered(arg2_num_elements);
std::vector<T> normalized(arg2_num_elements);
std::vector<T> mean(mean_num_elements);
std::vector<T> variance(mean_num_elements);
std::vector<T> stddev(mean_num_elements);
batch_norm_three_outputs_with_intermediates(eps,
arg0,
arg1,
arg2,
bn_output.data(),
mean.data(),
variance.data(),
centered.data(),
normalized.data(),
arg2_shape);
for (size_t i = 0; i < mean_num_elements; i++)
{
stddev[i] = std::sqrt(variance[i] + eps_casted);
}
// Broadcast gamma and the standard deviation
std::vector<T> gamma_bcast(arg2_num_elements);
std::vector<T> stddev_bcast(arg2_num_elements);
broadcast(arg0, gamma_bcast.data(), mean_shape, arg2_shape, reduction_axes);
broadcast(
stddev.data(), stddev_bcast.data(), mean_shape, arg2_shape, reduction_axes);
// Bprop into gamma
std::vector<T> delta_times_normalized(arg2_num_elements);
multiply(normalized.data(), arg5, delta_times_normalized.data(), arg2_num_elements);
sum(delta_times_normalized.data(), out1, arg2_shape, mean_shape, reduction_axes);
// Bprop into beta
sum(arg5, out2, arg2_shape, mean_shape, reduction_axes);
// Broadcast the gamma and beta gradients
std::vector<T> delta_gamma_bcast(arg2_num_elements);
broadcast(out1, delta_gamma_bcast.data(), mean_shape, arg2_shape, reduction_axes);
std::vector<T> delta_beta_bcast(arg2_num_elements);
broadcast(out2, delta_beta_bcast.data(), mean_shape, arg2_shape, reduction_axes);
// Bprop into the input
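// d_input = gamma / sqrt(var + eps)
//           * (delta - (normalized * d_gamma + d_beta) / m),
// where m = reduction_axes_size, the element count reduced per channel.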
for (size_t i = 0; i < arg2_num_elements; i++)
{
auto scale_normalized = gamma_bcast[i] / stddev_bcast[i];
out0[i] = static_cast<T>(
scale_normalized *
(arg5[i] -
(normalized[i] * delta_gamma_bcast[i] + delta_beta_bcast[i]) /
reduction_axes_size));
}
}
}
}
}
@@ -1584,6 +1584,32 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
ASSERT_TRUE(read_vector<float>(output) == expected);
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_three_outputs)
{
auto shape_in = Shape{2, 3, 1, 1};
auto shape_mean = Shape{3};
auto make_graph = [shape_in, shape_mean] {
auto A = make_shared<op::Parameter>(element::f64, shape_in);
auto B = make_shared<op::Parameter>(element::f64, shape_mean);
auto C = make_shared<op::Parameter>(element::f64, shape_mean);
auto BN = make_shared<op::BatchNorm>(1e-3, B, C, A);
auto f = make_shared<Function>(make_shared<op::GetOutputElement>(BN, 0),
op::ParameterVector{A, B, C});
return f;
};
auto backend = runtime::Backend::create("${BACKEND_NAME}");
test::Uniform<double> rng(-1.0, 1.0);
auto x0 = rng.initialize(backend->create_tensor<double>(shape_in));
auto x1 = rng.initialize(backend->create_tensor<double>(shape_mean));
auto x2 = rng.initialize(backend->create_tensor<double>(shape_mean));
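// Compare the autodiff gradients with numerically estimated gradients
// at the given tolerances.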
EXPECT_TRUE(autodiff_numeric_compare<double>(backend, make_graph, {x0, x1, x2}, .01, .01));
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n3_c2_h3)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......
@@ -208,8 +208,8 @@ namespace ngraph
auto clone_bwd = s_clone_bwd_map[f];
auto cache_dfdx = get_autodiff<T>(backend, clone_bwd, mod_df_input_args, indep_params);
const auto numpy_atol = 1e-5f;
const auto numpy_rtol = 1e-8f;
const T numpy_atol = static_cast<const T>(1e-5f);
const T numpy_rtol = static_cast<const T>(1e-8f);
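// Express the tolerances in the test's element type T.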
auto close = ngraph::test::all_close<T>(dfdx, cache_dfdx, numpy_atol, numpy_rtol);
if (!close)
{
......