//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <algorithm>
#include <memory>
#include <stdexcept>
#include <vector>

#include "ngraph/runtime/backend.hpp"
#include "ngraph/type/element_type.hpp"

namespace ngraph
{
    namespace autodiff
    {
29 30 31 32 33 34
        /// \brief numeric approximation of the derivative
        /// \param f A function
        /// \param args Values for the arguments (the independent variables)
        /// \param delta increment for the variables
        /// \param indep_params parameters with respect to which to compute derivatives
        /// \returns vector of dy/dvar, where each dy/dvar's shape is concat(y.shape(), var.shape())
35
        template <typename T>
36
        std::vector<std::shared_ptr<runtime::Tensor>>
37
            numeric_derivative(runtime::Backend* backend,
38
                               const std::shared_ptr<Function>& f,
39
                               const std::vector<std::shared_ptr<runtime::Tensor>>& args,
40 41
                               T delta,
                               const std::vector<std::shared_ptr<op::Parameter>>& indep_params)
Robert Kimball's avatar
Robert Kimball committed
42
        {
43
            Shape y_shape = f->get_output_shape(0);
Scott Cyphers's avatar
Scott Cyphers committed
44

Robert Kimball's avatar
Robert Kimball committed
45 46 47
            auto params = f->get_parameters();

            // Results for each derivative, shape Y|X_i
48
            std::vector<std::shared_ptr<runtime::Tensor>> results;
49 50

            for (auto param : indep_params)
Robert Kimball's avatar
Robert Kimball committed
51 52
            {
                Shape s = y_shape;
53
                auto param_shape = param->get_shape();
Robert Kimball's avatar
Robert Kimball committed
54
                s.insert(s.end(), param_shape.begin(), param_shape.end());
55
                results.push_back(backend->create_tensor<T>(s));
Robert Kimball's avatar
Robert Kimball committed
56 57 58
            }

            // ref_y is the function evaluated at the args
59
            auto ref_y = backend->create_tensor<T>(y_shape);
Robert Kimball's avatar
Robert Kimball committed
60

61
            auto f_handle = backend->compile(f);
62

63 64
            f_handle->call_with_validate(
                std::vector<std::shared_ptr<ngraph::runtime::Tensor>>{ref_y}, args);
65
            auto ref_vec = read_vector<T>(ref_y);
Robert Kimball's avatar
Robert Kimball committed
66 67

            // inc_y will hold f(x+dx) values
68
            auto inc_y = backend->create_tensor<T>(y_shape);
Robert Kimball's avatar
Robert Kimball committed
69 70 71

            // Assuming vars, y, and results are row-major

72
            T inv_delta = 1 / delta;
73 74 75

            size_t pos = 0;

Robert Kimball's avatar
Robert Kimball committed
76 77
            for (size_t i = 0; i < args.size(); ++i)
            {
78 79
                if (std::find(indep_params.begin(), indep_params.end(), params[i]) !=
                    indep_params.end())
Robert Kimball's avatar
Robert Kimball committed
80
                {
81
                    auto arg = args[i];
82 83
                    auto res = read_vector<T>(results[pos]);
                    auto vec = read_vector<T>(arg);
84
                    for (size_t j = 0; j < vec.size(); j++)
Robert Kimball's avatar
Robert Kimball committed
85
                    {
86 87
                        auto old_val = vec[j];
                        vec[j] += delta;
88
                        write_vector(arg, vec);
89
                        f_handle->call_with_validate({inc_y}, args);
90
                        auto inc_vec = read_vector<T>(inc_y);
91
                        vec[j] = old_val;
92
                        write_vector(arg, vec);
93 94 95 96 97 98 99 100
                        size_t res_k = j;
                        for (size_t k = 0; k < inc_vec.size(); k++)
                        {
                            auto y1 = inc_vec[k];
                            auto y0 = ref_vec[k];
                            res[res_k] = inv_delta * (y1 - y0);
                            res_k += vec.size();
                        }
Robert Kimball's avatar
Robert Kimball committed
101
                    }
102
                    write_vector(results[pos], res);
103
                    pos++;
Robert Kimball's avatar
Robert Kimball committed
104 105 106 107
                }
            }
            return results;
        }
Scott Cyphers's avatar
Scott Cyphers committed
108 109
    }
}