/*
// Copyright (c) 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

#include "scale_grad_weights_inst.h"
#include "primitive_gpu_base.h"
#include "implementation_map.h"
#include "kernel_selector_helper.h"
#include "scale_grad_weights/scale_grad_weights_kernel_selector.h"
#include "scale_grad_weights/scale_grad_weights_kernel_base.h"
#include "error_handler.h"
#include "network_impl.h"

using namespace cldnn;

namespace cldnn {
namespace gpu {

// GPU implementation of the scale_grad_weights primitive.
// Constructors are inherited unchanged from typed_primitive_gpu_impl.
struct scale_grad_weights_gpu : typed_primitive_gpu_impl<scale_grad_weights> {
    using parent = typed_primitive_gpu_impl<scale_grad_weights>;
    using parent::parent;

protected:
    // Collects the memory objects and scalar settings handed to the kernel.
    //
    // instance - primitive instance supplying the input/output/weights memories,
    //            the momentum/bias flags and the owning network.
    // (second parameter is unused)
    //
    // Optional slots are filled with nullptr when the corresponding feature is off:
    //   bias              - only when the node reports a bias term,
    //   prev_weights_grad - only when momentum is in use,
    //   prev_bias_grad    - only when both a bias term and momentum are in use.
    kernel::kernel_arguments_data get_arguments(typed_primitive_inst<scale_grad_weights>& instance,
                                                int32_t) const override {
        kernel::kernel_arguments_data kargs;

        kargs.inputs = {(memory_impl::cptr) &instance.input_memory(0),
                        (memory_impl::cptr) &instance.input_memory(1)};
        kargs.output = (memory_impl::cptr) &instance.output_memory();

        // The bias flag is read from the node (_outer) here, and from the
        // instance further below, exactly as the flags are queried per slot.
        kargs.bias = (memory_impl::cptr) (_outer.bias_term() ? &instance.bias_memory() : nullptr);
        kargs.weights = (memory_impl::cptr) &instance.weights_memory();

        kargs.prev_weights_grad =
            (memory_impl::cptr) (instance.use_momentum() ? &instance.prev_scale_grad() : nullptr);

        // Short-circuit keeps the query order: bias flag first, momentum only if bias is set.
        kargs.prev_bias_grad =
            (memory_impl::cptr) ((instance.bias_term() && instance.use_momentum()) ? &instance.prev_bias_grad()
                                                                                   : nullptr);

        kargs.lr = instance.get_network().get_learning_rate();
        return kargs;
    }

public:
    // Factory registered with implementation_map.
    //
    // Builds the default learning parameters for the node, asks the kernel
    // selector for its best kernels, and wraps the first one in a new
    // scale_grad_weights_gpu. An empty selection is reported through
    // CLDNN_ERROR_BOOL before the first element is accessed.
    //
    // Returns an object allocated with `new`; the caller is expected to take
    // ownership (NOTE(review): confirm against the implementation_map consumer).
    static primitive_impl* create(const scale_grad_weights_node& arg) {
        auto params = get_default_learning_params<kernel_selector::scale_grad_weights_params>(arg);
        auto optional_params =
            get_default_learning_optional_params<kernel_selector::scale_grad_weights_optional_params>(
                arg.get_program());

        auto& selector = kernel_selector::scale_grad_weights_kernel_selector::Instance();
        auto candidates = selector.GetBestKernels(params, optional_params);

        CLDNN_ERROR_BOOL(arg.id(),
                         "Best_kernel.empty()",
                         candidates.empty(),
                         "Cannot find a proper kernel with this arguments");

        return new scale_grad_weights_gpu(arg, candidates[0]);
    }
};

namespace {

// Static registrar: constructing the single file-local instance below adds the
// factory to implementation_map for every supported (engine, data type, format)
// combination.
struct attach {
    attach() {
        using registry = implementation_map<scale_grad_weights>;
        auto factory = scale_grad_weights_gpu::create;

        // yxfb
        registry::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), factory);
        registry::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), factory);

        // bfyx
        registry::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), factory);
        registry::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), factory);

        // byxf
        registry::add(std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), factory);
        registry::add(std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), factory);
    }

    ~attach() {}
};

attach attach_impl;

}  // namespace
}  // namespace gpu
}  // namespace cldnn