/*
// Copyright (c) 2016-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "api/convolution.hpp"
#include "primitive_inst.h"

#include <memory>
#include <string>

namespace cldnn {

/// Program-graph node specialization for the deformable_conv primitive.
///
/// Dependency order used by the accessors below: the input is dependency 0,
/// weights start at dependency 1, and biases start at dependency 1 + split.
template <>
struct typed_program_node<deformable_conv> : public typed_program_node_base<deformable_conv> {
    using parent = typed_program_node_base<deformable_conv>;

public:
    /// Builds the node, taking the initial split and groups values from the primitive.
    /// Both optimization flags start as false.
    typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog)
            : parent(prim, prog),
              split(this->get_primitive()->split()),
              depthwise_sep_opt(false),
              // Must be initialized here: get_transposed() may be called before
              // set_transposed(), and reading an uninitialized bool is undefined behavior.
              transposed(false),
              groups(this->get_primitive()->groups) {
        support_padding_all(true);
    }

    void set_split(int32_t node_split) { split = node_split; }
    int32_t get_split() const { return split; }

    void set_depthwise_sep_opt(bool node_depthwise_sep_opt) { depthwise_sep_opt = node_depthwise_sep_opt; }
    bool get_depthwise_sep_opt() const { return depthwise_sep_opt; }

    void set_transposed(bool node_transposed) { transposed = node_transposed; }
    bool get_transposed() const { return transposed; }

    void set_groups(uint32_t node_groups) { groups = node_groups; }
    uint32_t get_groups() const { return groups; }

    /// @return the node feeding the input (dependency 0).
    program_node& input() const { return get_dependency(0); }

    /// @param idx index of the weights dependency, must be less than get_split().
    /// @return the weights node at dependency 1 + idx.
    /// @throws std::range_error if idx is not less than the current split.
    program_node& weights(size_t idx = 0) const {
        if (static_cast<int32_t>(idx) >= this->get_split())
            throw std::range_error("weights offset too big");

        return get_dependency(1 + idx);
    }

    /// @param idx index of the bias dependency, must be less than get_split().
    /// @return the bias node at dependency 1 + split + idx.
    /// @throws std::range_error if idx is not less than the current split.
    program_node& bias(size_t idx = 0) const {
        if (static_cast<int32_t>(idx) >= this->get_split())
            throw std::range_error("bias offset too big");

        return get_dependency(1 + this->get_split() + idx);
    }

    /// @return true when the primitive's bias list is non-empty.
    bool bias_term() const { return get_primitive()->bias.size() > 0; }

private:
    // Declaration order matches the constructor's initializer list.
    int32_t split;
    bool depthwise_sep_opt;
    bool transposed;
    uint32_t groups;
};

// Shorthand for the deformable_conv specialization of typed_program_node.
using deformable_conv_node = typed_program_node<deformable_conv>;

/// Primitive instance specialization for deformable_conv.
///
/// Gives access to the weights and bias memory objects among the instance's
/// dependencies, using the same indexing scheme as deformable_conv_node.
template <>
class typed_primitive_inst<deformable_conv> : public typed_primitive_inst_base<deformable_conv> {
    using parent = typed_primitive_inst_base<deformable_conv>;

public:
    static layout calc_output_layout(deformable_conv_node const& node);
    static std::string to_string(deformable_conv_node const& node);

public:
    typed_primitive_inst(network_impl& network, deformable_conv_node const& node);

    /// @param index index of the requested weights buffer; only used when groups == 1.
    /// @return dependency memory 1 + index when groups == 1, otherwise dependency memory 1.
    /// @throws std::range_error if groups == 1 and index is not less than the node's split.
    memory_impl& weights_memory(size_t index) const {
        if (node.get_groups() == 1) {
            if (static_cast<int32_t>(index) >= node.get_split())
                throw std::range_error("weights offset too big");
            return dep_memory(1 + index);
        } else {  // all weights are in one buffer
            return dep_memory(1);
        }
    }

    /// @param index index of the requested bias buffer; only used when groups == 1.
    /// @return dependency memory 1 + split + index when groups == 1, otherwise dependency memory 2.
    /// @throws std::range_error if groups == 1 and index is not less than the node's split.
    memory_impl& bias_memory(size_t index) const {
        if (node.get_groups() == 1) {
            if (static_cast<int32_t>(index) >= node.get_split())
                throw std::range_error("bias offset too big");
            // Biases follow the `split` weights dependencies. The index must be added so
            // each split gets its own bias, matching deformable_conv_node::bias().
            return dep_memory(1 + node.get_split() + index);
        } else {  // all bias are in one buffer
            return dep_memory(2);
        }
    }

    /// @return whether the underlying node reports a bias term.
    bool bias_term() const { return node.bias_term(); }
};

// Shorthand for the deformable_conv specialization of typed_primitive_inst.
using deformable_conv_inst = typed_primitive_inst<deformable_conv>;

/// Program-graph node specialization for the deformable_interp primitive.
///
/// Holds per-node copies of split, groups and deformable_groups plus two flags,
/// and exposes the node's first two dependencies.
template <>
struct typed_program_node<deformable_interp> : public typed_program_node_base<deformable_interp> {
    using parent = typed_program_node_base<deformable_interp>;

public:
    /// Builds the node: split starts at 1, both flags start as false, and the
    /// group counts are read from the primitive.
    typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog)
        : parent(prim, prog)
        , split(1)
        , depthwise_sep_opt(false)
        , transposed(false)
        , groups(this->get_primitive()->groups)
        , deformable_groups(this->get_primitive()->deformable_groups)
    {
        support_padding_all(true);
    }

    // --- dependency accessors ---

    /// @return dependency 0.
    program_node& input() const {
        return get_dependency(0);
    }

    /// @return dependency 1 (presumably the offsets/transformation input; confirm against the primitive).
    program_node& trans() const {
        return get_dependency(1);
    }

    // --- per-node configuration ---

    int32_t get_split() const {
        return split;
    }
    void set_split(int32_t value) {
        split = value;
    }

    bool get_depthwise_sep_opt() const {
        return depthwise_sep_opt;
    }
    void set_depthwise_sep_opt(bool enabled) {
        depthwise_sep_opt = enabled;
    }

    bool get_transposed() const {
        return transposed;
    }
    void set_transposed(bool enabled) {
        transposed = enabled;
    }

    uint32_t get_groups() const {
        return groups;
    }
    void set_groups(uint32_t count) {
        groups = count;
    }

    uint32_t get_deformable_groups() const {
        return deformable_groups;
    }
    void set_deformable_groups(uint32_t count) {
        deformable_groups = count;
    }

private:
    // Declaration order matches the constructor's initializer list.
    int32_t split;
    bool depthwise_sep_opt;
    bool transposed;
    uint32_t groups;
    uint32_t deformable_groups;
};

// Shorthand for the deformable_interp specialization of typed_program_node.
using deformable_interp_node = typed_program_node<deformable_interp>;

/// Primitive instance specialization for deformable_interp.
template <>
class typed_primitive_inst<deformable_interp> : public typed_primitive_inst_base<deformable_interp> {
    using parent = typed_primitive_inst_base<deformable_interp>;

public:
    typed_primitive_inst(network_impl& network, deformable_interp_node const& node);

    // Static helpers, defined out of line.
    static std::string to_string(deformable_interp_node const& node);
    static layout calc_output_layout(deformable_interp_node const& node);

    /// @return the memory object of dependency 1.
    memory_impl& trans_memory() const {
        return dep_memory(1);
    }
};

// Shorthand for the deformable_interp specialization of typed_primitive_inst.
using deformable_interp_inst = typed_primitive_inst<deformable_interp>;

}  // namespace cldnn