// numeric_derivative.hpp
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include <algorithm>
#include <memory>
#include <vector>

#include "ngraph/runtime/backend.hpp"
#include "ngraph/type/element_type.hpp"

namespace ngraph
{
    namespace autodiff
    {
        /// @brief numeric approximation of the derivative
        /// @param f A function
        /// @param args Values for the arguments (the independent variables)
        /// @param delta increment for the variables
33
        /// @param indep_params parameters with respect to which to compute derivatives
Scott Cyphers's avatar
Scott Cyphers committed
34
        /// @returns vector of dy/dvar, where each dy/dvar's shape is concat(y.shape(), var.shape())
35 36
        template <typename T>
        std::vector<std::shared_ptr<runtime::TensorView>>
37
            numeric_derivative(const std::shared_ptr<runtime::Backend>& backend,
38 39
                               const std::shared_ptr<Function>& f,
                               const std::vector<std::shared_ptr<runtime::TensorView>>& args,
40 41
                               T delta,
                               const std::vector<std::shared_ptr<op::Parameter>>& indep_params)
Robert Kimball's avatar
Robert Kimball committed
42
        {
43
            Shape y_shape = f->get_output_shape(0);
Scott Cyphers's avatar
Scott Cyphers committed
44

Robert Kimball's avatar
Robert Kimball committed
45 46 47
            auto params = f->get_parameters();

            // Results for each derivative, shape Y|X_i
48
            std::vector<std::shared_ptr<runtime::TensorView>> results;
49 50

            for (auto param : indep_params)
Robert Kimball's avatar
Robert Kimball committed
51 52
            {
                Shape s = y_shape;
53
                auto param_shape = param->get_shape();
Robert Kimball's avatar
Robert Kimball committed
54
                s.insert(s.end(), param_shape.begin(), param_shape.end());
55
                results.push_back(backend->create_tensor<T>(s));
Robert Kimball's avatar
Robert Kimball committed
56 57 58
            }

            // ref_y is the function evaluated at the args
59
            auto ref_y = backend->create_tensor<T>(y_shape);
Robert Kimball's avatar
Robert Kimball committed
60

61 62
            backend->call(
                f, std::vector<std::shared_ptr<ngraph::runtime::TensorView>>{ref_y}, args);
63
            auto ref_vec = read_vector<T>(ref_y);
Robert Kimball's avatar
Robert Kimball committed
64 65

            // inc_y will hold f(x+dx) values
66
            auto inc_y = backend->create_tensor<T>(y_shape);
Robert Kimball's avatar
Robert Kimball committed
67 68 69

            // Assuming vars, y, and results are row-major

70
            T inv_delta = 1 / delta;
71 72 73

            size_t pos = 0;

Robert Kimball's avatar
Robert Kimball committed
74 75
            for (size_t i = 0; i < args.size(); ++i)
            {
76 77
                if (std::find(indep_params.begin(), indep_params.end(), params[i]) !=
                    indep_params.end())
Robert Kimball's avatar
Robert Kimball committed
78
                {
79
                    auto arg = args[i];
80 81
                    auto res = read_vector<T>(results[pos]);
                    auto vec = read_vector<T>(arg);
82
                    for (size_t j = 0; j < vec.size(); j++)
Robert Kimball's avatar
Robert Kimball committed
83
                    {
84 85
                        auto old_val = vec[j];
                        vec[j] += delta;
86
                        write_vector(arg, vec);
87
                        backend->call(f, {inc_y}, args);
88
                        auto inc_vec = read_vector<T>(inc_y);
89
                        vec[j] = old_val;
90
                        write_vector(arg, vec);
91 92 93 94 95 96 97 98
                        size_t res_k = j;
                        for (size_t k = 0; k < inc_vec.size(); k++)
                        {
                            auto y1 = inc_vec[k];
                            auto y0 = ref_vec[k];
                            res[res_k] = inv_delta * (y1 - y0);
                            res_k += vec.size();
                        }
Robert Kimball's avatar
Robert Kimball committed
99
                    }
100
                    write_vector(results[pos], res);
101
                    pos++;
Robert Kimball's avatar
Robert Kimball committed
102 103 104 105
                }
            }
            return results;
        }
Scott Cyphers's avatar
Scott Cyphers committed
106 107
    }
}