// gather_gpu.cpp
/*
// Copyright (c) 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

#include "gather_inst.h"
#include "primitive_gpu_base.h"
#include "implementation_map.h"
#include "kernel_selector_helper.h"
#include "gather/gather_kernel_selector.h"
#include "gather/gather_kernel_ref.h"
#include "error_handler.h"

using namespace cldnn;

namespace cldnn {
namespace gpu {
// Translates the axis stored on the gather primitive into the axis enum used
// by the kernel selector. along_x, and any value without an explicit mapping
// below, resolves to the X axis.
kernel_selector::gather_axis convert_axis(gather::gather_axis axis) {
    if (axis == gather::along_y) {
        return kernel_selector::gather_axis::Y;
    }
    if (axis == gather::along_f) {
        return kernel_selector::gather_axis::FEATURE;
    }
    if (axis == gather::along_b) {
        return kernel_selector::gather_axis::BATCH;
    }
    // Covers gather::along_x as well as the fallback for unrecognised values.
    return kernel_selector::gather_axis::X;
}
// GPU implementation of the gather primitive, built on top of
// typed_primitive_gpu_impl and reusing its constructors.
struct gather_gpu : typed_primitive_gpu_impl<gather> {
    using parent = typed_primitive_gpu_impl<gather>;
    using parent::parent;

    // Factory used when registering this implementation.
    //
    // Fills the kernel-selector parameters for the given node (default
    // parameters, the converted gather axis, and the output layout of input 1
    // appended as an extra input tensor), asks the gather kernel selector for
    // its best kernels, and wraps the first one in a new gather_gpu.
    //
    // An empty kernel list is reported through CLDNN_ERROR_BOOL.
    // NOTE(review): the object is allocated with `new`; presumably the caller
    // takes ownership of the returned pointer - confirm against
    // implementation_map.
    static primitive_impl* create(const gather_node& arg) {
        auto params = get_default_params<kernel_selector::gather_params>(arg);
        auto optional_params =
            get_default_optional_params<kernel_selector::gather_optional_params>(arg.get_program());

        params.axis = convert_axis(arg.get_primitive()->axis);
        params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));

        auto& selector = kernel_selector::gather_kernel_selector::Instance();
        auto candidates = selector.GetBestKernels(params, optional_params);

        CLDNN_ERROR_BOOL(arg.id(),
                         "Best_kernel.empty()",
                         candidates.empty(),
                         "Cannot find a proper kernel with this arguments");

        // The selector's first entry is the one used for this implementation.
        return new gather_gpu(arg, candidates[0]);
    }
};
namespace detail {

// Registers gather_gpu::create in the gather implementation map for the OCL
// engine with bfyx format, once for f32 data and once for f16 data.
attach_gather_gpu::attach_gather_gpu() {
    auto factory = gather_gpu::create;

    const auto f32_bfyx_key = std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx);
    const auto f16_bfyx_key = std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx);

    implementation_map<gather>::add(f32_bfyx_key, factory);
    implementation_map<gather>::add(f16_bfyx_key, factory);
}

}  // namespace detail
}  // namespace gpu
}  // namespace cldnn