# ******************************************************************************
# Copyright 2018-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

"""Export a pretrained Gluon ResNet-50 v2 to a static graph and time inference.

Steps:
1. Hybridize the Gluon model-zoo ResNet-50 v2, trace it with one forward pass,
   and export it as a symbol + params checkpoint.
2. Reload that checkpoint and bind it to a CPU executor for inference only.
3. Run a few warm-up batches, then report the average latency of the timed
   batches in milliseconds.
"""

import time

import mxnet as mx
from mxnet.gluon.model_zoo import vision

# Input shape used both to trace the Gluon model and to bind the static
# executor (batch of 1, 3 x 224 x 224; NCHW is MXNet's default layout).
batch_shape = (1, 3, 224, 224)
input_data = mx.nd.zeros(batch_shape)

# Shared by ``export`` and ``load_checkpoint`` so the two calls cannot drift
# apart; ``export`` writes ``<prefix>-symbol.json`` and
# ``<prefix>-<epoch:04d>.params``.
checkpoint_prefix = 'resnet50_v2'
checkpoint_epoch = 0

# Convert the Gluon (imperative) model to a static symbolic model.
resnet_gluon = vision.resnet50_v2(pretrained=True)
resnet_gluon.hybridize()
# A hybridized block builds its cached graph on the first forward pass, and
# ``export`` requires that graph, so run one pass before exporting.
resnet_gluon.forward(input_data)
resnet_gluon.export(checkpoint_prefix, epoch=checkpoint_epoch)
resnet_sym, arg_params, aux_params = mx.model.load_checkpoint(
    checkpoint_prefix, checkpoint_epoch)

# Load the model into nGraph as a static graph: bind the symbol to a CPU
# executor (grad_req='null' because it is used for inference only) and copy in
# the trained weights.
model = resnet_sym.simple_bind(ctx=mx.cpu(), data=batch_shape, grad_req='null')
model.copy_params_from(arg_params, aux_params)

# To test the model's performance, adjust these counts as needed.
dry_run = 5        # warm-up iterations, excluded from the measurement
num_batches = 100  # timed iterations used for the average
if num_batches < 1:
    # Without at least one timed batch there is nothing to average.
    raise ValueError('num_batches must be at least 1 to measure latency')


def _run_batch():
    """Run one inference pass on ``input_data`` and wait for all outputs."""
    # MXNet executes asynchronously; blocking on every output makes the timing
    # cover the actual computation instead of just queueing it.
    for output in model.forward(data=input_data, is_train=False):
        output.wait_to_read()


for _ in range(dry_run):
    _run_batch()

# perf_counter is monotonic and high-resolution, unlike time.time().
start_time = time.perf_counter()
for _ in range(num_batches):
    _run_batch()
elapsed = time.perf_counter() - start_time

print("Average Latency = ", elapsed / num_batches * 1000, "ms")