/* // Copyright (c) 2016 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. */ /////////////////////////////////////////////////////////////////////////////////////////////////// #pragma once #include "api_extension/fused_conv_eltwise.hpp" #include "primitive_inst.h" #include <memory> #include <string> namespace cldnn { template <> struct typed_program_node<fused_conv_eltwise> : public typed_program_node_base<fused_conv_eltwise> { using parent = typed_program_node_base<fused_conv_eltwise>; public: typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog) : parent(prim, prog), split(this->get_primitive()->split()), depthwise_sep_opt(false), transposed(false), conv_input_qf(this->get_primitive()->conv.input_quantization_factor), conv_output_qf(this->get_primitive()->conv.output_quantization_factor) { if (get_primitive()->eltw.with_activation) { auto slope = get_primitive()->eltw.activation_negative_slope; if (slope == 0.f) { this->add_fused_activation(activation_func::relu, {}); } else { this->add_fused_activation(activation_func::relu_negative_slope, { slope, 0.f }); } } } void set_split(int32_t node_split) { split = node_split; } int32_t get_split() const { return split; } void set_depthwise_sep_opt(bool node_depthwise_sep_opt) { depthwise_sep_opt = node_depthwise_sep_opt; } bool get_depthwise_sep_opt() const { return depthwise_sep_opt; } void set_transposed(bool node_transposed) { transposed = node_transposed; } bool get_transposed() const { return transposed; } program_node& input(size_t idx = 0) const { if (static_cast<int32_t>(idx) >= static_cast<int32_t>(desc->input.size())) throw std::range_error("input index too big"); return get_dependency(idx); } program_node& weights(size_t idx = 0) const { if (static_cast<int32_t>(idx) >= this->get_split()) throw std::range_error("weights offset too big"); return get_dependency(desc->input.size() + idx); } program_node& bias(size_t idx = 0) const { if (static_cast<int32_t>(idx) >= this->get_split()) throw std::range_error("bias offset too big"); return get_dependency(desc->input.size() + this->get_split() + idx); } program_node& weights_quantization_factors(size_t idx = 0) const { if (static_cast<int32_t>(idx) >= this->get_split()) throw std::range_error("quantization factor offset too big"); return get_dependency(desc->input.size() + (1 + 1 * bias_term()) * this->get_split() + idx); } program_node& conv_output_calibration_factors(size_t idx = 0) const { if (static_cast<int32_t>(idx) >= this->get_split()) throw std::range_error("calibration factor offset too big"); return get_dependency(desc->input.size() + (1 + 1 * bias_term() + 1 * weights_quantization_term()) * this->get_split() + idx); } program_node& eltw_output_calibration_factors() const { return get_dependency(desc->input.size() + (1 + 1 * bias_term() + 1 * weights_quantization_term() + 1 * conv_output_calibration_term()) * this->get_split()); } bool bias_term() const { return get_primitive()->conv.bias.size() > 0; } bool weights_quantization_term() const { return get_primitive()->conv.weights_quantization_factors.size() > 0; } bool conv_output_calibration_term() const { return get_primitive()->conv.output_calibration_factors.size() > 0; } bool eltw_output_calibration_term() const { return get_primitive()->eltw.output_calibration_factors.size() > 0; } float get_conv_input_qf() const { return conv_input_qf; } float get_conv_output_qf() const { return conv_output_qf; } float get_eltw_output_qf() const { return eltw_output_qf; } private: int32_t split; bool depthwise_sep_opt; bool transposed; float conv_input_qf; float conv_output_qf; float eltw_output_qf; }; using fused_conv_eltwise_node = typed_program_node<fused_conv_eltwise>; template <> class typed_primitive_inst<fused_conv_eltwise> : public typed_primitive_inst_base<fused_conv_eltwise> { using parent = typed_primitive_inst_base<fused_conv_eltwise>; public: static layout calc_output_layout(fused_conv_eltwise_node const& node); static std::string to_string(fused_conv_eltwise_node const& node); public: typed_primitive_inst(network_impl& network, fused_conv_eltwise_node const& node); memory_impl& weights_memory(size_t index) const { if (static_cast<int32_t>(index) >= node.get_split()) throw std::range_error("weights offset too big"); return dep_memory(2 + index); } memory_impl& bias_memory(size_t index) const { if (static_cast<int32_t>(index) >= node.get_split()) throw std::range_error("bias offset too big"); return dep_memory(2 + node.get_split() + index); } memory_impl& weights_quantization_factors_memory(size_t index) const { if (static_cast<int32_t>(index) >= node.get_split()) throw std::range_error("quantization factors offset too big"); return dep_memory(2 + (1 + 1 * bias_term()) * node.get_split() + index); } memory_impl& output_calibration_factors_memory(size_t index) const { if (static_cast<int32_t>(index) >= node.get_split()) throw std::range_error("quantization factors offset too big"); return dep_memory(2 + (1 + 1 * bias_term() + 1 * weights_quantization_factors_term()) * node.get_split() + index); } memory_impl& eltw_output_calibration_factors_memory() const { return dep_memory(2 + (1 + 1 * bias_term() + 1 * weights_quantization_factors_term() + 1 * conv_output_calibration_factors_term()) * node.get_split()); } bool bias_term() const { return node.bias_term(); } bool weights_quantization_factors_term() const { return node.weights_quantization_term(); } bool conv_output_calibration_factors_term() const { return node.conv_output_calibration_term(); } bool eltw_output_calibration_factors_term() const { return node.eltw_output_calibration_term(); } }; using fused_conv_eltwise_inst = typed_primitive_inst<fused_conv_eltwise>; } // namespace cldnn