//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/log.hpp"
#include "ngraph/type/element_type.hpp"
#include "util/all_close.hpp"
#include "util/autodiff/backprop_derivative.hpp"
#include "util/autodiff/numeric_derivative.hpp"
#include "util/test_tools.hpp"

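// Helpers for the autodiff tests: each comparison computes numerical (finite-difference)
// derivatives of `f` on the reference INTERPRETER backend and symbolic (backprop) derivatives
// of an identical function `g` on the backend under test, then checks the two against each
// other with ngraph::test::all_close at the given relative and absolute tolerances.
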
// TODO: Consider removing the template parameter, since only <float> is used in tests and the
//       numerical derivative does not work with integer types.
// TODO: Always compute the numerical derivatives in double.
template <typename T>
::testing::AssertionResult
    autodiff_numeric_compare(ngraph::runtime::Backend* backend,
                             std::shared_ptr<ngraph::Function> f,
                             std::shared_ptr<ngraph::Function> g,
                             const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& args,
                             T rtol,
                             T atol)
{
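    // Note: 0.0009765625 == 2^-10, so the finite-difference step below is exactly representable
    // in binary floating point and the step value itself introduces no rounding error.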
    T delta = static_cast<T>(0.0009765625f); // Binary-representable number near 0.001

    // Use INTERPRETER to compute numerical derivatives
    auto interpreter_backend = ngraph::runtime::Backend::create("INTERPRETER");

    std::vector<std::shared_ptr<ngraph::runtime::Tensor>> interpreter_args;
    for (auto arg : args)
    {
        auto interpreter_arg =
            interpreter_backend->create_tensor(arg->get_element_type(), arg->get_shape());

        // TODO: copy_data should not require T. Quick fix here for bool used in `Select`
        if (arg->get_element_type() == ngraph::element::boolean)
        {
            copy_data(interpreter_arg, read_vector<char>(arg));
        }
        else
        {
            copy_data(interpreter_arg, read_vector<T>(arg));
        }
        interpreter_args.push_back(interpreter_arg);
    }
    auto results_num = ngraph::autodiff::numeric_derivative<T>(
        interpreter_backend.get(), f, interpreter_args, delta, f->get_parameters());

    // Use the backend being tested to compute symbolic derivatives
    auto results_sym =
        ngraph::autodiff::backprop_derivative<T>(backend, g, args, g->get_parameters());

    // Cast to HostTensor for comparison
    std::vector<std::shared_ptr<ngraph::runtime::Tensor>> interpreter_results_sym;
    for (auto result : results_sym)
    {
        auto interpreter_result =
            interpreter_backend->create_tensor(ngraph::element::from<T>(), result->get_shape());
        copy_data(interpreter_result, read_vector<T>(result));
        interpreter_results_sym.push_back(interpreter_result);
    }

    return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol);
}

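// Convenience overload: calls make_graph twice so the numerical pass (on INTERPRETER) and the
// symbolic pass (on the backend under test) each receive their own identical Function.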
template <typename T>
::testing::AssertionResult
    autodiff_numeric_compare(ngraph::runtime::Backend* backend,
                             std::function<std::shared_ptr<ngraph::Function>()> make_graph,
                             const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& args,
                             T rtol,
                             T atol)
{
    return autodiff_numeric_compare(backend, make_graph(), make_graph(), args, rtol, atol);
}
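
// Example (illustrative sketch, not a prescribed usage): a test for a hypothetical one-parameter
// tanh graph might call the overload above roughly like this:
//
//     auto backend = ngraph::runtime::Backend::create("INTERPRETER"); // or the backend under test
//     auto shape = ngraph::Shape{2, 3};
//     auto make_graph = [shape]() {
//         auto X = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, shape);
//         return std::make_shared<ngraph::Function>(std::make_shared<ngraph::op::Tanh>(X),
//                                                   ngraph::ParameterVector{X});
//     };
//     auto x = backend->create_tensor(ngraph::element::f32, shape);
//     copy_data(x, std::vector<float>{0, 1, 2, 3, 4, 5});
//     EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), make_graph, {x}, .01f, .01f));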

template <typename T>
::testing::AssertionResult autodiff_numeric_compare_selective(
    ngraph::runtime::Backend* backend,
    std::shared_ptr<ngraph::Function> f,
    std::shared_ptr<ngraph::Function> g,
    const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& args,
    T rtol,
    T atol,
    const std::vector<bool>& indep_param_mask)
{
    // Use INTERPRETER to compute numerical derivatives
    std::vector<std::shared_ptr<ngraph::op::Parameter>> f_indep_params;

    size_t i = 0;

    for (auto b : indep_param_mask)
    {
        if (b)
        {
            f_indep_params.push_back(f->get_parameters().at(i));
        }
        i++;
    }

    auto interpreter_backend = ngraph::runtime::Backend::create("INTERPRETER");

    std::vector<std::shared_ptr<ngraph::runtime::Tensor>> interpreter_args;
    for (auto arg : args)
    {
        auto interpreter_arg =
            interpreter_backend->create_tensor(arg->get_element_type(), arg->get_shape());

        // TODO: copy_data should not require T. Quick fix here for bool used in `Select`
        if (arg->get_element_type() == ngraph::element::boolean)
        {
            copy_data(interpreter_arg, read_vector<char>(arg));
        }
        else
        {
            copy_data(interpreter_arg, read_vector<T>(arg));
        }
        interpreter_args.push_back(interpreter_arg);
    }
    // Binary-representable delta near 0.001, matching autodiff_numeric_compare above
    T delta = static_cast<T>(0.0009765625f);
    auto results_num = ngraph::autodiff::numeric_derivative<T>(
        interpreter_backend.get(), f, interpreter_args, delta, f_indep_params);

    // Use the backend being tested to compute symbolic derivatives
    std::vector<std::shared_ptr<ngraph::op::Parameter>> g_indep_params;

    i = 0;

    for (auto b : indep_param_mask)
    {
        if (b)
        {
            g_indep_params.push_back(g->get_parameters().at(i));
        }
        i++;
    }

    auto results_sym = ngraph::autodiff::backprop_derivative<T>(backend, g, args, g_indep_params);

    // Cast to HostTensor for comparison
    std::vector<std::shared_ptr<ngraph::runtime::Tensor>> interpreter_results_sym;
    for (auto result : results_sym)
    {
        auto interpreter_result =
            interpreter_backend->create_tensor(ngraph::element::from<T>(), result->get_shape());
        copy_data(interpreter_result, read_vector<T>(result));
        interpreter_results_sym.push_back(interpreter_result);
    }

    return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol);
}

template <typename T>
::testing::AssertionResult autodiff_numeric_compare_selective(
    ngraph::runtime::Backend* backend,
    std::function<std::shared_ptr<ngraph::Function>()> make_graph,
    const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& args,
    T rtol,
    T atol,
    const std::vector<bool>& indep_param_mask)
{
    return autodiff_numeric_compare_selective(
        backend, make_graph(), make_graph(), args, rtol, atol, indep_param_mask);
}
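
// Example (illustrative sketch): the *_selective overloads differentiate only with respect to the
// parameters whose entry in indep_param_mask is true. For a hypothetical two-parameter graph
// builder `make_graph` with argument tensors `x0` and `x1`, restricting to the first parameter:
//
//     EXPECT_TRUE(autodiff_numeric_compare_selective<float>(
//         backend.get(), make_graph, {x0, x1}, .01f, .01f, std::vector<bool>{true, false}));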