scale_factor_calc.hpp 23.6 KB
Newer Older
1
// Copyright (C) 2018-2019 Intel Corporation
2 3 4 5 6 7 8 9 10
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include <vector>
#include <algorithm>
#include <utility>
#include <limits>
#include <string>
11
#include <map>
12
#include <gna_upstream_iterator.hpp>
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
#include "gna_layer_info.hpp"
#include "ie_layers.h"
#include "gna_plugin_log.hpp"

namespace GNAPluginNS {
namespace details {
using namespace InferenceEngine;
struct ScaleFactorUpdateResult {
    CNNLayer *restartLayer = nullptr;
    ScaleFactorUpdateResult() = default;
    explicit ScaleFactorUpdateResult(CNNLayer * restartlayer) : restartLayer(restartlayer) {
    }
    operator bool() {
        return restartLayer == nullptr;
    }
};

/**
 * @brief calculates output scale factor per layer
 * @tparam T
 */
template<class T>
class ScaleFactorPerLayer {
 public:
    /**
     * @brief calculates weights scale factor for fit dynamic range into target bitsize,
     * also calculates output scale factor for the given layer
     * @param cnnLayer
     * @param weightsSize
     * @param result
     * @return
     */
45
    bool operator()(T cnnLayer, int weightsSize, ScaleFactorUpdateResult &result) {
46 47 48 49 50 51 52 53 54 55 56
        return false;
    }
};

template<>
class ScaleFactorPerLayer<InferenceEngine::CNNLayer *> {
 private :
    const float activation_scale_factor = 2048.f;
    const float identity_scale_factor = 2049.0f;
    const float k = 5;
    const float k_identity = 6;
57 58 59 60 61 62 63 64 65 66

 protected :
    static bool fp32eq(float p1, float p2) {
        return (std::abs(p1 - p2) <= 0.00001f * std::min(std::abs(p1), std::abs(p2)));
    }
    float getActivationScale(GNAPluginNS::LayerInfo const&  layer, QuantizedLayerParams const* qunatizedParams) {
            // todo: calculate proper scale factor where we need to expand it a bit to be safe to stay in int16 weights
            // set the initial value
            float result = 1.0f;
            result = (layer.isIdentity()) ? identity_scale_factor : activation_scale_factor;
67
            // if activation is one from relu family, we need to apply heuristic to avoid activation output overflow
68 69 70 71 72 73 74 75
            if (layer.isRelu() &&
                    static_cast<uint64_t>(result * qunatizedParams->_src_quant.scale)
                                                                > std::numeric_limits<int32_t>::max()-1) {
                result = (result * 0.5);
            }
            return result;
    }

76
 public :
77
    bool operator()(InferenceEngine::CNNLayer *cnnLayer, int weightsSize, ScaleFactorUpdateResult &result) {
78 79 80 81 82 83
        if ( !cnnLayer ) {
            THROW_IE_EXCEPTION << "Incorrect Convolutional Layer pointer \n";
        }
        LayerInfo layerInfo(*cnnLayer);
        // TODO: current approach set input scale factor for true input layer(s) equals to provided factor,
        auto quant = getInjectedData<QuantizedLayerParams>(*cnnLayer);
84

85
        if (InferenceEngine::details::CaselessEq<std::string>()(cnnLayer->type, "Memory")) {
86
             if (CNNNetHasPrevLayer(cnnLayer)) {
87
                auto prevLayer = CNNNetPrevLayer(cnnLayer);
88
                auto prevInfo = LayerInfo(prevLayer);
89
                auto inputQuant = getInjectedData<QuantizedLayerParams>(prevLayer);
90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
               // locating corresponding memory layers ith same ID
                for (auto && input : CNNNetGetAllInputLayers(cnnLayer)) {
                    LayerInfo ll(input);
                    if (!ll.isMemory() ||
                        !InferenceEngine::details::CaselessEq<std::string>()(input->params["id"], cnnLayer->params["id"])) {
                        continue;
                    }

                    auto quantSibling = getInjectedData<QuantizedLayerParams>(input);

                    // after restarting from memory input - quant is fine
                    if (fp32eq(quantSibling->_dst_quant.scale, inputQuant->_dst_quant.scale)) {
                        quant->_src_quant.scale = quant->_dst_quant.scale = inputQuant->_dst_quant.scale;
                        return true;
                    }

                    if (!fp32eq(quantSibling->_dst_quant.scale, 1)) {
                        // means we already restarted propagation from that memory layer - we cannot do mach here
                        THROW_GNA_EXCEPTION << "quantization error : input scale factor ( " << inputQuant->_dst_quant.scale <<") "
                                  << " for " << cnnLayer->name << ", that is child of " << prevLayer->name <<" doesnt match : "
                                  << activation_scale_factor;
                    }

                    gnawarn() << "[INFO] quantization : input scale factor (" << inputQuant->_dst_quant.scale <<")"
                              << " for " << cnnLayer->name << ", that is child of " << prevLayer->name <<" doesnt match : "
                              << activation_scale_factor << ", restarting from corresponding memory: "<< input->name << std::endl;

                    // try updating memory input layer scale factor and restart from it
                    quantSibling->_src_quant.scale = quantSibling->_dst_quant.scale = inputQuant->_dst_quant.scale;
                    result = ScaleFactorUpdateResult(input.get());
120 121 122 123 124 125 126
                    return true;
                }
            }
            return true;
        }

        if (!CNNNetHasPrevLayer(cnnLayer)) {
127
            quant->_dst_quant.scale = quant->_src_quant.scale;
128 129 130 131 132
            return ScaleFactorUpdateResult();
        }

        // by default layer is pass thru its scale factor
        auto inputQuant = getInjectedData<QuantizedLayerParams>(CNNNetPrevLayer(cnnLayer));
133 134 135
        if (!inputQuant) {
            THROW_GNA_EXCEPTION << "layer: " << CNNNetPrevLayer(cnnLayer)->name << "not quantized";
        }
136 137 138 139 140 141
        quant->_dst_quant.scale = inputQuant->_dst_quant.scale;
        quant->_src_quant.scale = inputQuant->_dst_quant.scale;

        if (layerInfo.isActivation()) {
            // todo: calculate proper scale factor where we need to expand it a bit to be safe to stay in int16 weights
            // set the initial value
142
            quant->_dst_quant.scale = getActivationScale(layerInfo, quant);
143 144 145 146 147 148 149 150
        }
        return true;
    }
};

template<>
class ScaleFactorPerLayer<InferenceEngine::EltwiseLayer*> {
 public:
151
    bool operator()(InferenceEngine::EltwiseLayer* eltwiseLayer, int weightsSize, ScaleFactorUpdateResult &result) {
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
        if ( !eltwiseLayer ) {
            THROW_GNA_EXCEPTION << "Incorrect Eltwise Layer pointer \n";
        }
        auto in0 = InferenceEngine::CNNNetPrevLayer(eltwiseLayer, 0);
        auto in1 = InferenceEngine::CNNNetPrevLayer(eltwiseLayer, 1);

        auto quantParams0 = InferenceEngine::getInjectedData<QuantizedLayerParams>(in0);
        auto quantParams1 = InferenceEngine::getInjectedData<QuantizedLayerParams>(in1);
        auto quantData = InferenceEngine::getInjectedData<QuantizedLayerParams>(*eltwiseLayer);

        switch (eltwiseLayer->_operation) {
            case InferenceEngine::EltwiseLayer::Prod: {
                quantData->_weights_quant.scale = quantParams1->_dst_quant.scale;
                quantData->_dst_quant.scale     = quantParams0->_dst_quant.scale * quantParams1->_dst_quant.scale;
                break;
            }
            case InferenceEngine::EltwiseLayer::Sum: {
                // detect which input will be used as biases
                if (LayerInfo(in0).has32BOutput()) {
                    std::swap(in0, in1);
                    std::swap(quantParams0, quantParams1);
                }

                // this path might result in significant data loss
                quantData->_weights_quant.scale = quantParams1->_dst_quant.scale / quantParams0->_dst_quant.scale;
                quantData->_dst_quant.scale = quantParams1->_dst_quant.scale;

                // eltwise will always work in int16
                auto maxValue = std::numeric_limits<int16_t>::max() - 1;
                if (quantData->_weights_quant.scale > maxValue + 1) {
                    // rescaling it's activation input
                    // iterating thru previous layers of eltwise
                    for (uint8_t i = 0; i < 2; ++i) {
                        InferenceEngine::CNNLayerPtr in = InferenceEngine::CNNNetPrevLayer(eltwiseLayer, i);
                        // trick to get opposite index (for 0 -> 1 for 1 -> 0) by inversing i.
                        auto quantParams =
                                InferenceEngine::getInjectedData<QuantizedLayerParams>(InferenceEngine::CNNNetPrevLayer(eltwiseLayer, !i));

                        for (; InferenceEngine::CNNNetHasPrevLayer(in.get()); in = CNNNetPrevLayer(in)) {
                            auto info = LayerInfo(in);
                            // we skipping only split layers so far, also need to work on memory layers
                            // this case for input from port 0
                            if (info.isSplit() || info.isSlice()) {
                                continue;
                            } else if (info.has16BOutput() && info.isActivation()) {
                                auto newOutputScale = quantParams->_dst_quant.scale / maxValue;
198
                                if (newOutputScale > static_cast<float>(std::numeric_limits<int16_t>::max()) / 2) {
199 200 201 202 203 204 205 206 207 208 209 210 211 212
                                    break;
                                }
                                auto quantDataForActivation = InferenceEngine::getInjectedData<QuantizedLayerParams>(*in);
                                gnawarn() << "[WARNING] saturated weights for " << eltwiseLayer->name
                                         << ". Layer new output scale: " << in->name << ", output_scale=" << newOutputScale
                                         << ", was " << quantDataForActivation->_dst_quant.scale <<"\n" << std::flush;
                                quantDataForActivation->_dst_quant.scale = newOutputScale;
                                result = ScaleFactorUpdateResult(in.get());
                                return true;
                            } else if (info.has16BOutput()) {
                                break;
                            }

                            // if we are here it means that we are in the port 1
213
                            if (info.isFullyConnected() || info.isConvolution()) {
214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234
                                auto quantDataForInputLayer = InferenceEngine::getInjectedData<QuantizedLayerParams>(*in);
                                auto newOutputScale = quantParams->_dst_quant.scale * maxValue;
                                auto newWeightScale = newOutputScale / quantDataForInputLayer->_src_quant.scale;
                                quantDataForInputLayer->_dst_quant.scale = newOutputScale;
                                quantDataForInputLayer->_weights_quant.scale = newWeightScale;
                                result = ScaleFactorUpdateResult(in.get());
                                return true;
                            }
                        }
                    }
                    // we unable to rescale the input - results might be bad
                    gnawarn() << "[INFO] weights saturated for " << eltwiseLayer->name << "\n";
                }
                break;
            }
            default : THROW_GNA_EXCEPTION << "Unsupported Eltwise layer for quantisation: " << eltwiseLayer->_operation;
        }
        return true;
    }
};

235 236 237
template<>
class ScaleFactorPerLayer<InferenceEngine::ConcatLayer*> {
 public:
238
    bool operator()(InferenceEngine::ConcatLayer* concatLayer, int weightsSize, ScaleFactorUpdateResult &result) {
239 240 241 242 243 244 245 246 247 248 249 250
        if ( !concatLayer ) {
            THROW_GNA_EXCEPTION << "Incorrect Concat Layer pointer \n";
        }
        auto in0 = InferenceEngine::CNNNetPrevLayer(concatLayer, 0);
        auto in1 = InferenceEngine::CNNNetPrevLayer(concatLayer, 1);
        auto infoIn0 = LayerInfo(in0);
        auto infoIn1 = LayerInfo(in1);
        auto quantParams0 = InferenceEngine::getInjectedData<QuantizedLayerParams>(in0);
        auto quantParams1 = InferenceEngine::getInjectedData<QuantizedLayerParams>(in1);
        GNAPluginNS::QuantizedLayerParams* sourceQuantParams = NULL;
        auto quantData = InferenceEngine::getInjectedData<QuantizedLayerParams>(*concatLayer);

251 252 253 254 255 256 257 258
        auto fp32eq = [](float p1, float p2) -> bool {
            return (std::abs(p1 - p2) <= 0.00001f * std::min(std::abs(p1), std::abs(p2)));
        };

        // if both inputs have same quant value - trivial propagation
        if (fp32eq(quantParams0->_dst_quant.scale, quantParams1->_dst_quant.scale)) {
            quantData->_dst_quant.scale = quantParams0->_dst_quant.scale;
            quantData->_src_quant.scale = quantParams0->_dst_quant.scale;
259
            return true;
260 261 262
        }
        // support only cases when one of input is network input
        if (infoIn0.isInput() && infoIn1.isInput()) {
263 264 265
            THROW_GNA_EXCEPTION << "Two Input layers has different scales in concat!!! \n";
        }

266 267
        int concatIdxToUpdate = -1;

268 269 270
        if (infoIn0.isInput()) {
            sourceQuantParams = quantParams0;
        } else if (infoIn1.isInput()) {
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
            concatIdxToUpdate = 0;
            sourceQuantParams = quantParams1;
        }

        // possible case when some of the concat inputs are free to select scale ex: const->concat<-affine
        if (quantParams1->_dst_quant.scale == 1.0) {
            quantParams1->_weights_quant = quantParams0->_dst_quant;
            quantParams1->_dst_quant     = quantParams0->_dst_quant;

            sourceQuantParams = quantParams0;
        }

        if (quantParams0->_dst_quant.scale == 1.0) {
            quantParams0->_weights_quant = quantParams1->_dst_quant;
            quantParams0->_dst_quant     = quantParams1->_dst_quant;
286 287 288 289
            sourceQuantParams = quantParams1;
        }

        if (!sourceQuantParams) {
290 291 292 293 294 295 296 297 298 299 300 301 302
            auto in0LayerInfo = LayerInfo(in0);
            auto in1LayerInfo = LayerInfo(in1);
            if (in0LayerInfo.isActivation()) {
                quantParams0->_weights_quant = quantParams1->_dst_quant;
                quantParams0->_dst_quant = quantParams1->_dst_quant;
                sourceQuantParams = quantParams1;
            } else if (in1LayerInfo.isActivation()) {
                quantParams1->_weights_quant = quantParams0->_dst_quant;
                quantParams1->_dst_quant = quantParams0->_dst_quant;
                sourceQuantParams = quantParams0;
            } else {
                THROW_GNA_EXCEPTION << "Concat quantization for this case need to be implemented!!! \n";
            }
303
        }
304 305 306 307

        if (!fp32eq(quantParams0->_dst_quant.scale, quantParams1->_dst_quant.scale) && concatIdxToUpdate == -1) {
            THROW_GNA_EXCEPTION << "layers entered into concat have different scale factors" << concatLayer->name;
        }
308 309 310 311

        quantData->_dst_quant.scale = sourceQuantParams->_dst_quant.scale;
        quantData->_src_quant.scale = sourceQuantParams->_dst_quant.scale;

312 313 314 315 316
        if (concatIdxToUpdate == -1) {
            return true;
        }

        auto destinationQuantParams = InferenceEngine::getInjectedData<QuantizedLayerParams>(*concatLayer);
317
        destinationQuantParams->_dst_quant.scale = sourceQuantParams->_dst_quant.scale;
318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351

        CNNLayerPtr  restartedLayer;
        // making a link activation possible without extra layer if first input to concat not a parent / indirect parent of second input
        // using ufs - upper first search
        gnalog() << "[UFS] searching for quantizeable layer prior: "<< concatLayer->name << ", via " << concatIdxToUpdate << "\n";

        CNNNetDFS(CNNLayerPtr(concatLayer, [](CNNLayer*){}), [&restartedLayer, concatLayer](CNNLayerPtr layer) {
            gnalog() << "[UFS] from : "<< concatLayer->name <<" reached: " << layer->name;
            // found that direct input to concat is a indirect parent of align filter - so no link required
            auto info = LayerInfo(layer);
            if (!info.isWeightable() && !info.isActivation()) {
                gnalog() << "... skipped\n";
                return;
            }
            restartedLayer = layer;
            gnalog() << "... OK,  need requantize\n";
        }, true, [&restartedLayer, &concatLayer, &concatIdxToUpdate](InferenceEngine::CNNLayer* from) {
            // aborting UFS once found functional layer, and using only specified input of concat
            return make_upstream_order(restartedLayer == nullptr? from : nullptr,
                                       from == concatLayer ? concatIdxToUpdate : -1);
        });

        if (restartedLayer == nullptr) {
            THROW_GNA_EXCEPTION << "cannot requantize " << concatIdxToUpdate << "input to concat: " << concatLayer->name;
        }
        auto quantDataForConCatInput = InferenceEngine::getInjectedData<QuantizedLayerParams>(*restartedLayer);

        auto restarLayerInfo = LayerInfo(restartedLayer);
        if (restarLayerInfo.isActivation()) {
            // requantize activation by just changing it's output scale factor
            quantDataForConCatInput->_dst_quant.scale = sourceQuantParams->_dst_quant.scale;
        }

        result = ScaleFactorUpdateResult(restartedLayer.get());
352 353 354 355 356

        return true;
    }
};

357 358 359 360 361 362 363 364 365 366 367 368 369 370
template<>
class ScaleFactorPerLayer<InferenceEngine::WeightableLayer*> {
 private:
    float const _scale_reduction_50 = 0.50;
    float const _scale_reduction_45 = 0.45;
    float const _scale_reduction_40 = 0.40;
    float const _scale_reduction_35 = 0.35;

    uint16_t const _scale_change_req_threshold = 30;
    uint16_t const _scale_change_threshold_100 = 100;
    uint16_t const _scale_change_threshold_150 = 150;
    uint16_t const _scale_change_threshold_200 = 200;

 public:
371
    bool operator()(InferenceEngine::WeightableLayer *wl, int weightsSize, ScaleFactorUpdateResult &result) {
372 373 374 375 376 377 378 379 380 381 382
        if ( !wl ) {
            THROW_GNA_EXCEPTION << "Incorrect Weightable Layer pointer  \n";
        } else if (!wl->_weights) {
            THROW_GNA_EXCEPTION << "Incorrect weight value for " << wl->name << ":" << wl->type << "\n";
        }

        auto prevLayer = CNNNetPrevLayer(wl);
        auto quantDataForInputLayer =
            InferenceEngine::getInjectedData<QuantizedLayerParams>(*InferenceEngine::CNNNetPrevLayer(wl).get());

        auto quant = InferenceEngine::getInjectedData<QuantizedLayerParams>(*wl);
383
        quant->_src_quant.scale = quantDataForInputLayer->_dst_quant.scale;
384 385 386 387 388 389 390 391 392 393 394 395 396
        // TODO: pass 8 bits somehow
        if (quant->_weights_quant.scale == 1.0f) {
            size_t scaleRange = 0;
            if (weightsSize == 2) {
                scaleRange = MAX_VAL_2B_WEIGHT;
            } else if (weightsSize == 1) {
                scaleRange = MAX_VAL_1B_WEIGHT;
            } else {
                THROW_GNA_EXCEPTION << "Unsupported weights size of: " << weightsSize;
            }
            quant->_weights_quant.scale =
                ScaleFactorForQuantization(wl->_weights->buffer().as<float *>(), scaleRange, wl->_weights->size());

397 398 399 400 401 402 403 404
            if (wl->_biases) {
                quant->_bias_quant.scale = ScaleFactorForQuantization(wl->_biases->buffer().as<float *>(),
                                                                      MAX_VAL_4B_BIAS,
                                                                      wl->_biases->size());
                quant->_bias_quant.scale = std::min(quant->_weights_quant.scale * quant->_src_quant.scale, quant->_bias_quant.scale);
                quant->_weights_quant.scale = quant->_bias_quant.scale / quant->_src_quant.scale;
            }

405 406 407 408
            // TODO: findout why ???
            if (weightsSize == 1) {
                quant->_weights_quant.scale *= MAX_OUT_MULTIPLIER;
            }
409 410 411 412 413 414 415 416 417 418

            double weights_reducer = 1.0;
            auto conv = dynamic_cast<ConvolutionLayer*>(wl);
            if (conv) {
                auto dims = conv->insData.front().lock()->getDims();

                weights_reducer = MAX_VAL_2B_FEAT * scaleRange * dims[1] / std::numeric_limits<int32_t>::max();
                weights_reducer = std::max(1.0, weights_reducer);
            }
            quant->_weights_quant.scale /= weights_reducer;
419 420
        }

421

422 423 424 425 426 427 428
        double tmp_dst_quant_scale = quant->_weights_quant.scale * quantDataForInputLayer->_dst_quant.scale;

        if (weightsSize == 1 &&
            static_cast<uint64_t>(tmp_dst_quant_scale * quant->_src_quant.scale) >
                                                    static_cast<uint64_t>(std::numeric_limits<int32_t>::max()-1) * _scale_change_req_threshold) {
            gnawarn() << "Output scale for " << wl->name
                                            << " too large and are being reduced. Else saturations likely will happen \n";
429
            // reduce weight scale according experimental heuristic
430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453
            if (quant->_dst_quant.scale * quant->_src_quant.scale / std::numeric_limits<int32_t>::max() < _scale_change_threshold_100) {
                quant->_weights_quant.scale *= _scale_reduction_50;
                tmp_dst_quant_scale *= _scale_reduction_50;
            } else if (quant->_dst_quant.scale * quant->_src_quant.scale / std::numeric_limits<int32_t>::max() < _scale_change_threshold_150) {
                quant->_weights_quant.scale *= _scale_reduction_45;
                tmp_dst_quant_scale *= _scale_reduction_45;
            } else if (quant->_dst_quant.scale * quant->_src_quant.scale / std::numeric_limits<int32_t>::max() < _scale_change_threshold_200) {
                quant->_weights_quant.scale *= _scale_reduction_40;
                tmp_dst_quant_scale *= _scale_reduction_40;
            } else {
                quant->_weights_quant.scale *= _scale_reduction_35;
                tmp_dst_quant_scale *= _scale_reduction_35;
            }
        }

        quant->_dst_quant.scale = tmp_dst_quant_scale;

        return true;
    }
};

template<>
class ScaleFactorPerLayer<InferenceEngine::ScaleShiftLayer*> : public ScaleFactorPerLayer<InferenceEngine::WeightableLayer*> {
 public:
454 455
    bool operator()(InferenceEngine::WeightableLayer *wl, int weightsSize, ScaleFactorUpdateResult &result) {
        return ScaleFactorPerLayer<InferenceEngine::WeightableLayer*>::operator()(wl, 2, result);
456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480
    }
};

/**
 * GNA convolutions cannot be quantized in int8, remove when library starts support that
 */
template<>
class ScaleFactorPerLayer<InferenceEngine::ConvolutionLayer*> : public ScaleFactorPerLayer<InferenceEngine::ScaleShiftLayer*> {
};


}  // namespace details

/**
 * @brief scale factor calculator will calculate only output scale factors for the layer
 * if scale factor propagation not possible, it will fall indicate a restart condition
 */
class ScaleFactorCalculator {
    using Cnt = std::vector<InferenceEngine::CNNLayerPtr>;
    Cnt  net;
    mutable Cnt::const_iterator idx;
    mutable bool needRestart = false;
    int weightsBytesSize;

 public:
481 482
    ScaleFactorCalculator(Cnt &net, int weightsBytesSize)
            : net(net), weightsBytesSize(weightsBytesSize) {
483 484 485 486 487 488 489 490 491 492 493 494 495 496 497
        idx = std::begin(this->net);
    }
    bool needToRestart() const {
        return needRestart;
    }
    bool allLayersProcessed() const {
        return idx == std::end(net);
    }
    std::vector<InferenceEngine::CNNLayerPtr> getStartLayers() const {
        return std::vector<InferenceEngine::CNNLayerPtr>(idx, std::end(net));
    }
    template<class T>
    bool operator()(T ptr) const {
        needRestart = false;
        details::ScaleFactorUpdateResult result;
498
        if (!details::ScaleFactorPerLayer<T>()(ptr, weightsBytesSize, result)) {
499 500 501 502 503 504 505 506 507 508 509 510 511
            return false;
        }
        if (result) {
            idx++;
            return true;
        }

        idx = std::find_if(net.begin(), net.end(), [&](InferenceEngine::CNNLayerPtr cnnLayer) {
            if (!result) {
                return result.restartLayer == cnnLayer.get();
            }
            return ptr == cnnLayer.get();
        });
512 513 514
        if (idx != net.end()) {
            idx++;
        }
515 516 517 518 519 520
        needRestart = true;
        return true;
    }
};

}  // namespace GNAPluginNS