numeric_compare.hpp 5.73 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
adstraw's avatar
adstraw committed
16

17
#include "ngraph/log.hpp"
18
#include "ngraph/type/element_type.hpp"
adstraw's avatar
adstraw committed
19 20 21
#include "util/all_close.hpp"
#include "util/autodiff/backprop_derivative.hpp"
#include "util/autodiff/numeric_derivative.hpp"
22
#include "util/test_tools.hpp"
adstraw's avatar
adstraw committed
23

24 25 26
// TODO: Consider removing template since only <float> is being used in tests and numerical
//       derivative does not work with int types
// TODO: Always compute the numerical derivatives in double
adstraw's avatar
adstraw committed
27
template <typename T>
28
bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Backend>& backend,
adstraw's avatar
adstraw committed
29 30 31 32 33
                              std::function<std::shared_ptr<ngraph::Function>()> make_graph,
                              const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args,
                              T rtol,
                              T atol)
{
34
    T delta = static_cast<T>(0.0009765625f); // Binary-representable number near 0.001
35 36

    // Use INTERPRETER to compute numerical derivatives
37
    auto interpreter_backend = ngraph::runtime::Backend::create("INTERPRETER");
adstraw's avatar
adstraw committed
38
    auto f = make_graph();
39 40 41 42

    std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_args;
    for (auto arg : args)
    {
43
        auto interpreter_arg = interpreter_backend->create_tensor(
44 45 46 47 48 49 50 51 52 53 54 55 56
            arg->get_tensor().get_element_type(), arg->get_shape());

        // TODO: copy_data should not require T. Quick fix here for bool used in `Select`
        if (arg->get_tensor().get_element_type() == ngraph::element::boolean)
        {
            copy_data(interpreter_arg, read_vector<char>(arg));
        }
        else
        {
            copy_data(interpreter_arg, read_vector<T>(arg));
        }
        interpreter_args.push_back(interpreter_arg);
    }
adstraw's avatar
adstraw committed
57
    auto results_num = ngraph::autodiff::numeric_derivative<T>(
58
        interpreter_backend, f, interpreter_args, delta, f->get_parameters());
adstraw's avatar
adstraw committed
59

60
    // Use the backend being tested to compute symbolic derivatives
adstraw's avatar
adstraw committed
61 62
    auto g = make_graph();
    auto results_sym =
63
        ngraph::autodiff::backprop_derivative<T>(backend, g, args, g->get_parameters());
adstraw's avatar
adstraw committed
64

65 66 67 68
    // Cast to HostTensorView for comparision
    std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_results_sym;
    for (auto result : results_sym)
    {
69 70
        auto interpreter_result =
            interpreter_backend->create_tensor(ngraph::element::from<T>(), result->get_shape());
71 72 73 74 75
        copy_data(interpreter_result, read_vector<T>(result));
        interpreter_results_sym.push_back(interpreter_result);
    }

    return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol);
adstraw's avatar
adstraw committed
76 77 78 79 80 81 82 83 84 85 86
}

/// Checks symbolic (backprop) derivatives produced by `backend` against numerically
/// estimated derivatives computed on the INTERPRETER backend. Only the graph parameters
/// selected by `indep_param_mask` are differentiated against.
///
/// @param backend          Backend under test; used for the symbolic derivatives.
/// @param make_graph       Factory that builds a fresh copy of the Function to differentiate.
///                         It is called once for the numeric pass and once for the symbolic
///                         pass.
/// @param args             Input tensors. They are passed as-is to the symbolic pass on
///                         `backend` and mirrored onto INTERPRETER for the numeric pass.
///                         Boolean tensors are copied as `char`; every other tensor is read
///                         as `T`.
/// @param rtol             Relative tolerance forwarded to ngraph::test::all_close.
/// @param atol             Absolute tolerance forwarded to ngraph::test::all_close.
/// @param indep_param_mask Entry i == true selects parameter i as an independent variable.
///                         Parameters past the end of the mask are not selected. A true
///                         entry whose index exceeds the parameter list makes
///                         `get_parameters().at(i)` throw std::out_of_range.
/// @return true when all_close reports the numeric and symbolic results as close.
template <typename T>
bool autodiff_numeric_compare_selective(
    const std::shared_ptr<ngraph::runtime::Backend>& backend,
    std::function<std::shared_ptr<ngraph::Function>()> make_graph,
    const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& args,
    T rtol,
    T atol,
    const std::vector<bool>& indep_param_mask)
{
    // Finite-difference step, typed as T and identical to autodiff_numeric_compare:
    // 2^-10 is exactly representable in binary and close to 0.001.
    const T delta = static_cast<T>(0.0009765625f);

    // Returns the parameters of `fn` whose mask entry is true, in parameter order.
    // Used for both graph instances so the numeric and symbolic passes select identically.
    const auto select_indep_params = [&indep_param_mask](
        const std::shared_ptr<ngraph::Function>& fn)
        -> std::vector<std::shared_ptr<ngraph::op::Parameter>> {
        std::vector<std::shared_ptr<ngraph::op::Parameter>> selected;
        const auto& params = fn->get_parameters();
        for (size_t i = 0; i < indep_param_mask.size(); ++i)
        {
            if (indep_param_mask[i])
            {
                selected.push_back(params.at(i));
            }
        }
        return selected;
    };

    // Use INTERPRETER to compute numerical derivatives
    auto f = make_graph();
    auto f_indep_params = select_indep_params(f);

    auto interpreter_backend = ngraph::runtime::Backend::create("INTERPRETER");

    // Mirror every input onto the INTERPRETER backend. Iterate by const reference so no
    // shared_ptr refcount churn happens per element.
    std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_args;
    interpreter_args.reserve(args.size());
    for (const auto& arg : args)
    {
        auto interpreter_arg = interpreter_backend->create_tensor(
            arg->get_tensor().get_element_type(), arg->get_shape());

        // TODO: copy_data should not require T. Quick fix here for bool used in `Select`
        if (arg->get_tensor().get_element_type() == ngraph::element::boolean)
        {
            copy_data(interpreter_arg, read_vector<char>(arg));
        }
        else
        {
            copy_data(interpreter_arg, read_vector<T>(arg));
        }
        interpreter_args.push_back(interpreter_arg);
    }
    auto results_num = ngraph::autodiff::numeric_derivative<T>(
        interpreter_backend, f, interpreter_args, delta, f_indep_params);

    // Use the backend being tested to compute symbolic derivatives
    auto g = make_graph();
    auto g_indep_params = select_indep_params(g);
    auto results_sym = ngraph::autodiff::backprop_derivative<T>(backend, g, args, g_indep_params);

    // Copy the symbolic results into INTERPRETER tensors so both result sets live on the
    // same backend when they are compared.
    std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_results_sym;
    interpreter_results_sym.reserve(results_sym.size());
    for (const auto& result : results_sym)
    {
        auto interpreter_result =
            interpreter_backend->create_tensor(ngraph::element::from<T>(), result->get_shape());
        copy_data(interpreter_result, read_vector<T>(result));
        interpreter_results_sym.push_back(interpreter_result);
    }

    return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol);
}