// numeric_compare.hpp
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once

#include <functional>
#include <memory>
#include <vector>

#include "ngraph/log.hpp"
#include "ngraph/type/element_type.hpp"
#include "util/all_close.hpp"
#include "util/autodiff/backprop_derivative.hpp"
#include "util/autodiff/numeric_derivative.hpp"
#include "util/test_tools.hpp"

// TODO: Consider removing the template, since only <float> is used in tests and the numerical
//       derivative does not work with integer types.
// TODO: Always compute the numerical derivatives in double precision.
template <typename T>
28
bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& backend,
29 30
                              std::shared_ptr<ngraph::Function> f,
                              std::shared_ptr<ngraph::Function> g,
adstraw's avatar
adstraw committed
31 32 33 34
                              const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args,
                              T rtol,
                              T atol)
{
35
    T delta = static_cast<T>(0.0009765625f); // Binary-representable number near 0.001
36 37

    // Use INTERPRETER to compute numerical derivatives
38
    auto interpreter_backend = ngraph::runtime::Backend::create("INTERPRETER");
39 40 41 42

    std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_args;
    for (auto arg : args)
    {
43
        auto interpreter_arg = interpreter_backend->create_tensor(
44 45 46 47 48 49 50 51 52 53 54 55 56
            arg->get_tensor().get_element_type(), arg->get_shape());

        // TODO: copy_data should not require T. Quick fix here for bool used in `Select`
        if (arg->get_tensor().get_element_type() == ngraph::element::boolean)
        {
            copy_data(interpreter_arg, read_vector<char>(arg));
        }
        else
        {
            copy_data(interpreter_arg, read_vector<T>(arg));
        }
        interpreter_args.push_back(interpreter_arg);
    }
adstraw's avatar
adstraw committed
57
    auto results_num = ngraph::autodiff::numeric_derivative<T>(
58
        interpreter_backend, f, interpreter_args, delta, f->get_parameters());
adstraw's avatar
adstraw committed
59

60
    // Use the backend being tested to compute symbolic derivatives
adstraw's avatar
adstraw committed
61
    auto results_sym =
62
        ngraph::autodiff::backprop_derivative<T>(backend, g, args, g->get_parameters());
adstraw's avatar
adstraw committed
63

64 65 66 67
    // Cast to HostTensorView for comparision
    std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_results_sym;
    for (auto result : results_sym)
    {
68 69
        auto interpreter_result =
            interpreter_backend->create_tensor(ngraph::element::from<T>(), result->get_shape());
70 71 72 73 74
        copy_data(interpreter_result, read_vector<T>(result));
        interpreter_results_sym.push_back(interpreter_result);
    }

    return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol);
adstraw's avatar
adstraw committed
75 76
}

77 78 79 80 81 82 83 84 85 86
/// Convenience overload of autodiff_numeric_compare that takes a graph factory.
///
/// Calls `make_graph` twice and forwards both results to the two-function overload. The
/// first result is differentiated numerically and the second symbolically.
/// NOTE(review): this assumes `make_graph` returns a fresh Function on each call.
///
/// @param backend    Backend used to compute the symbolic derivatives.
/// @param make_graph Factory that builds the Function under test.
/// @param args       Input tensors for the Function.
/// @param rtol       Relative tolerance for the comparison.
/// @param atol       Absolute tolerance for the comparison.
/// @return true if the numerical and symbolic derivatives agree.
template <typename T>
bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& backend,
                              std::function<std::shared_ptr<ngraph::Function>()> make_graph,
                              const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args,
                              T rtol,
                              T atol)
{
    auto numeric_graph = make_graph();
    auto symbolic_graph = make_graph();
    return autodiff_numeric_compare(
        backend, numeric_graph, symbolic_graph, args, rtol, atol);
}

/// Like autodiff_numeric_compare, but differentiates only with respect to the
/// parameters selected by `indep_param_mask`.
///
/// Entry i of the mask corresponds to parameter i of `f` (and of `g`). Parameters whose
/// entry is true are treated as independent variables. A true entry at an index with no
/// matching parameter makes `std::vector::at` throw `std::out_of_range`.
///
/// The numerical derivatives of `f` are computed on INTERPRETER, using copies of `args`.
/// The symbolic derivatives of `g` are computed on `backend`. Boolean arguments are
/// copied as `char`; every other argument is read as `T`.
///
/// @param backend          Backend used to compute the symbolic derivatives.
/// @param f                Function differentiated numerically (on INTERPRETER).
/// @param g                Function differentiated symbolically (on `backend`).
/// @param args             Input tensors, presumably one per parameter in order.
/// @param rtol             Relative tolerance forwarded to ngraph::test::all_close.
/// @param atol             Absolute tolerance forwarded to ngraph::test::all_close.
/// @param indep_param_mask Per-parameter flags selecting the independent parameters.
/// @return true if all numerical and symbolic derivatives are close.
template <typename T>
bool autodiff_numeric_compare_selective(
    const std::shared_ptr<ngraph::runtime::Backend>& backend,
    std::shared_ptr<ngraph::Function> f,
    std::shared_ptr<ngraph::Function> g,
    const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args,
    T rtol,
    T atol,
    const std::vector<bool>& indep_param_mask)
{
    // Use the same step as the non-selective overload: 2^-10 is near 0.001 but exactly
    // representable in binary floating point. (Previously a hard-coded .001f.)
    const T delta = static_cast<T>(0.0009765625f);

    // Collect the parameters of `fn` whose mask entry is true, preserving parameter order.
    auto select_indep_params = [&indep_param_mask](const std::shared_ptr<ngraph::Function>& fn) {
        std::vector<std::shared_ptr<ngraph::op::Parameter>> selected;
        const auto& params = fn->get_parameters();
        for (size_t i = 0; i < indep_param_mask.size(); ++i)
        {
            if (indep_param_mask[i])
            {
                selected.push_back(params.at(i));
            }
        }
        return selected;
    };

    // Use INTERPRETER to compute numerical derivatives
    const auto f_indep_params = select_indep_params(f);
    auto interpreter_backend = ngraph::runtime::Backend::create("INTERPRETER");

    // Mirror every argument onto the INTERPRETER backend.
    std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_args;
    interpreter_args.reserve(args.size());
    for (const auto& arg : args)
    {
        const auto element_type = arg->get_tensor().get_element_type();
        auto interpreter_arg = interpreter_backend->create_tensor(element_type, arg->get_shape());

        // TODO: copy_data should not require T. Quick fix here for bool used in `Select`
        if (element_type == ngraph::element::boolean)
        {
            copy_data(interpreter_arg, read_vector<char>(arg));
        }
        else
        {
            copy_data(interpreter_arg, read_vector<T>(arg));
        }
        interpreter_args.push_back(interpreter_arg);
    }
    auto results_num = ngraph::autodiff::numeric_derivative<T>(
        interpreter_backend, f, interpreter_args, delta, f_indep_params);

    // Use the backend being tested to compute symbolic derivatives
    const auto g_indep_params = select_indep_params(g);
    auto results_sym = ngraph::autodiff::backprop_derivative<T>(backend, g, args, g_indep_params);

    // Copy the symbolic results onto INTERPRETER so that both result sets live on the
    // same backend for the comparison.
    std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_results_sym;
    interpreter_results_sym.reserve(results_sym.size());
    for (const auto& result : results_sym)
    {
        auto interpreter_result =
            interpreter_backend->create_tensor(ngraph::element::from<T>(), result->get_shape());
        copy_data(interpreter_result, read_vector<T>(result));
        interpreter_results_sym.push_back(interpreter_result);
    }

    return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol);
}

/// Convenience overload of autodiff_numeric_compare_selective that takes a graph factory.
///
/// Calls `make_graph` twice: the first Function is differentiated numerically and the
/// second symbolically. Both are forwarded, together with `indep_param_mask`, to the
/// two-function overload. NOTE(review): this assumes `make_graph` returns a fresh
/// Function on each call.
///
/// @param backend          Backend used to compute the symbolic derivatives.
/// @param make_graph       Factory that builds the Function under test.
/// @param args             Input tensors for the Function.
/// @param rtol             Relative tolerance for the comparison.
/// @param atol             Absolute tolerance for the comparison.
/// @param indep_param_mask Per-parameter flags selecting the independent parameters.
/// @return true if the numerical and symbolic derivatives agree.
template <typename T>
bool autodiff_numeric_compare_selective(
    const std::shared_ptr<ngraph::runtime::Backend>& backend,
    std::function<std::shared_ptr<ngraph::Function>()> make_graph,
    const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args,
    T rtol,
    T atol,
    const std::vector<bool>& indep_param_mask)
{
    auto numeric_graph = make_graph();
    auto symbolic_graph = make_graph();
    return autodiff_numeric_compare_selective(
        backend, numeric_graph, symbolic_graph, args, rtol, atol, indep_param_mask);
}