# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

import io
import os
import re
import sys

import setuptools
from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext

# Imported after setuptools, which may provide its own copy of distutils.
import distutils.ccompiler

__version__ = os.environ.get('NGRAPH_VERSION', '0.0.0-dev')

# Directory containing this setup.py; source and package paths are built from it.
PYNGRAPH_ROOT_DIR = os.path.abspath(os.path.dirname(__file__))

# Parent of the fallback "ngraph_dist" directory used when NGRAPH_CPP_BUILD_PATH
# is unset. On POSIX, expanduser('~') uses $HOME when set and otherwise looks the
# home directory up, whereas os.environ.get('HOME') would be None when HOME is
# unset, which os.path.join() cannot handle.
NGRAPH_DEFAULT_INSTALL_DIR = os.path.expanduser('~')

# Build switch for the ONNX importer bindings. It is compared against 'TRUE'/'ON'
# further down, so normalise the case here to also accept 'true'/'on'.
# Stays None when the variable is unset.
_onnx_import_enable = os.environ.get('NGRAPH_ONNX_IMPORT_ENABLE')
NGRAPH_ONNX_IMPORT_ENABLE = (
    _onnx_import_enable.upper() if _onnx_import_enable is not None else None
)


def find_ngraph_dist_dir():
    """Return location of compiled ngraph library home.

    The directory comes from the NGRAPH_CPP_BUILD_PATH environment variable when
    it is set to a non-empty value, otherwise from
    ``NGRAPH_DEFAULT_INSTALL_DIR/ngraph_dist``. It is accepted only if it
    contains ``include/ngraph``.

    Prints a message and exits the process with status 1 when no location can
    be determined or the directory does not contain ``include/ngraph``.
    """
    ngraph_dist_dir = os.environ.get('NGRAPH_CPP_BUILD_PATH')
    if not ngraph_dist_dir:
        if NGRAPH_DEFAULT_INSTALL_DIR is None:
            # No default location to probe: report it instead of letting
            # os.path.join(None, ...) raise a TypeError.
            print('Cannot determine the default nGraph location, make sure that '
                  'NGRAPH_CPP_BUILD_PATH is set correctly')
            sys.exit(1)
        ngraph_dist_dir = os.path.join(NGRAPH_DEFAULT_INSTALL_DIR, 'ngraph_dist')

    if not os.path.exists(os.path.join(ngraph_dist_dir, 'include', 'ngraph')):
        print('Cannot find nGraph library in {} make sure that '
              'NGRAPH_CPP_BUILD_PATH is set correctly'.format(ngraph_dist_dir))
        sys.exit(1)

    print('nGraph library found in {}'.format(ngraph_dist_dir))
    return ngraph_dist_dir


def find_pybind_headers_dir():
    """Return location of pybind11 headers.

    Uses PYBIND_HEADERS_PATH when it is set to a non-empty value, otherwise the
    ``pybind11`` directory next to this setup script. The directory must contain
    ``include/pybind11``; if it does not, a message is printed and the process
    exits with status 1.
    """
    pybind_dir = (os.environ.get('PYBIND_HEADERS_PATH')
                  or os.path.join(PYNGRAPH_ROOT_DIR, 'pybind11'))

    if not os.path.exists(os.path.join(pybind_dir, 'include/pybind11')):
        print('Cannot find pybind11 library in {} make sure that '
              'PYBIND_HEADERS_PATH is set correctly'.format(pybind_dir))
        sys.exit(1)

    print('pybind11 library found in {}'.format(pybind_dir))
    return pybind_dir


NGRAPH_CPP_DIST_DIR = find_ngraph_dist_dir()
PYBIND11_INCLUDE_DIR = find_pybind_headers_dir() + '/include'
NGRAPH_CPP_INCLUDE_DIR = NGRAPH_CPP_DIST_DIR + '/include'

# The nGraph library directory is looked up as lib/ first, then lib64/.
if os.path.exists(NGRAPH_CPP_DIST_DIR + '/lib'):
    NGRAPH_CPP_LIBRARY_DIR = NGRAPH_CPP_DIST_DIR + '/lib'
elif os.path.exists(NGRAPH_CPP_DIST_DIR + '/lib64'):
    NGRAPH_CPP_LIBRARY_DIR = NGRAPH_CPP_DIST_DIR + '/lib64'
else:
    print('Cannot find library directory in {}, make sure that nGraph is installed '
          'correctly'.format(NGRAPH_CPP_DIST_DIR))
    sys.exit(1)

# For some platforms OpenVINO adds 'd' suffix to library names in debug
# configuration, so link against 'ngraphd' when any file name in the library
# directory contains it.
NGRAPH_CPP_LIBRARY_NAME = 'ngraph'
if any(re.search('ngraphd', filename) for filename in os.listdir(NGRAPH_CPP_LIBRARY_DIR)):
    NGRAPH_CPP_LIBRARY_NAME = 'ngraphd'


def parallelCCompile(
    self,
    sources,
    output_dir=None,
    macros=None,
    include_dirs=None,
    debug=0,
    extra_preargs=None,
    extra_postargs=None,
    depends=None,
):
    """Build sources in parallel.

    Drop-in replacement for ``distutils.ccompiler.CCompiler.compile``: the
    setup steps are the same as in CCompiler, but the per-object ``_compile``
    calls are dispatched to a thread pool.

    Reference link:
    http://stackoverflow.com/questions/11013851/speeding-up-build-process-with-distutils
    Monkey-patch for parallel compilation.

    Returns:
        The list of object file names produced by ``_setup_compile``.
    """
    # those lines are copied from distutils.ccompiler.CCompiler directly
    macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
        output_dir, macros, include_dirs, sources, depends, extra_postargs)
    cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)
    # parallel code
    import multiprocessing.pool

    def _single_compile(obj):
        # Objects missing from ``build`` are skipped.
        try:
            src, ext = build[obj]
        except KeyError:
            return
        self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)

    pool = multiprocessing.pool.ThreadPool()
    try:
        # convert to list, imap is evaluated on-demand; consuming it also
        # propagates an exception raised by any _single_compile call.
        list(pool.imap(_single_compile, objects))
    finally:
        # Release the worker threads, also when a compile step failed.
        pool.close()
        pool.join()
    return objects


# Replace CCompiler.compile process-wide, so compile() calls on compiler
# classes that do not override it go through parallelCCompile.
distutils.ccompiler.CCompiler.compile = parallelCCompile


def has_flag(compiler, flagname):
    """Check whether a flag is supported by the specified compiler.

    A minimal C++ program is compiled with ``flagname`` appended to the
    command line; a ``CompileError`` is taken to mean the flag is rejected.
    bpo-26689 (http://bugs.python.org/issue26689) proposed adding an
    equivalent ``has_flag`` method to CCompiler.

    Returns:
        True if the probe compiled, False on ``CompileError``.
    """
    import shutil
    import tempfile

    tmpdir = tempfile.mkdtemp()
    try:
        fname = os.path.join(tmpdir, 'flagcheck.cpp')
        # Write and close the file before compiling so the compiler sees the
        # full program rather than a still-unflushed (empty) file.
        with open(fname, 'w') as f:
            f.write('int main (int argc, char **argv) { return 0; }')
        try:
            # Keep the probe's object file inside tmpdir; without output_dir
            # its path is derived relative to the current working directory.
            compiler.compile([fname], output_dir=tmpdir, extra_postargs=[flagname])
        except setuptools.distutils.errors.CompileError:
            return False
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
    return True


def cpp_flag(compiler):
    """Return the ``-std=c++11`` flag after confirming the compiler accepts it.

    Raises:
        RuntimeError: if ``compiler`` rejects ``-std=c++11``.
    """
    flag = '-std=c++11'
    if not has_flag(compiler, flag):
        raise RuntimeError('Unsupported compiler -- C++11 support is needed!')
    return flag


# C++ binding sources of the _pyngraph extension, relative to PYNGRAPH_ROOT_DIR.
sources = [
    'pyngraph/function.cpp',
    'pyngraph/serializer.cpp',
    'pyngraph/node.cpp',
    'pyngraph/shape.cpp',
    'pyngraph/strides.cpp',
    'pyngraph/coordinate_diff.cpp',
    'pyngraph/axis_set.cpp',
    'pyngraph/axis_vector.cpp',
    'pyngraph/coordinate.cpp',
    'pyngraph/pyngraph.cpp',
    'pyngraph/util.cpp',
    'pyngraph/ops/util/arithmetic_reduction.cpp',
    'pyngraph/ops/util/binary_elementwise_comparison.cpp',
    'pyngraph/ops/util/op_annotations.cpp',
    'pyngraph/ops/util/binary_elementwise_arithmetic.cpp',
    'pyngraph/ops/util/binary_elementwise_logical.cpp',
    'pyngraph/ops/util/regmodule_pyngraph_op_util.cpp',
    'pyngraph/ops/util/unary_elementwise_arithmetic.cpp',
    'pyngraph/ops/util/index_reduction.cpp',
    'pyngraph/ops/abs.cpp',
    'pyngraph/ops/acos.cpp',
    'pyngraph/ops/add.cpp',
    'pyngraph/ops/and.cpp',
    'pyngraph/ops/argmax.cpp',
    'pyngraph/ops/argmin.cpp',
    'pyngraph/ops/asin.cpp',
    'pyngraph/ops/atan.cpp',
    'pyngraph/ops/avg_pool.cpp',
    'pyngraph/ops/broadcast.cpp',
    'pyngraph/ops/broadcast_distributed.cpp',
    'pyngraph/ops/fused/clamp.cpp',
    'pyngraph/ops/concat.cpp',
    'pyngraph/ops/constant.cpp',
    'pyngraph/ops/convert.cpp',
    'pyngraph/ops/convolution.cpp',
    'pyngraph/ops/cos.cpp',
    'pyngraph/ops/cosh.cpp',
    'pyngraph/ops/ceiling.cpp',
    'pyngraph/ops/fused/depth_to_space.cpp',
    'pyngraph/ops/dequantize.cpp',
    'pyngraph/ops/divide.cpp',
    'pyngraph/ops/dot.cpp',
    'pyngraph/ops/fused/elu.cpp',
    'pyngraph/ops/equal.cpp',
    'pyngraph/ops/exp.cpp',
    'pyngraph/ops/fused/fake_quantize.cpp',
    'pyngraph/ops/floor.cpp',
    'pyngraph/ops/fused/gelu.cpp',
    'pyngraph/ops/fused/gemm.cpp',
    'pyngraph/ops/greater.cpp',
    'pyngraph/ops/greater_eq.cpp',
    'pyngraph/ops/fused/grn.cpp',
    'pyngraph/ops/fused/group_conv.cpp',
    'pyngraph/ops/fused/hard_sigmoid.cpp',
    'pyngraph/ops/less.cpp',
    'pyngraph/ops/less_eq.cpp',
    'pyngraph/ops/log.cpp',
    'pyngraph/ops/lrn.cpp',
    'pyngraph/ops/maximum.cpp',
    'pyngraph/ops/max.cpp',
    'pyngraph/ops/product.cpp',
    'pyngraph/ops/max_pool.cpp',
    'pyngraph/ops/minimum.cpp',
    'pyngraph/ops/multiply.cpp',
    'pyngraph/ops/fused/mvn.cpp',
    'pyngraph/ops/negative.cpp',
    'pyngraph/ops/not.cpp',
    'pyngraph/ops/not_equal.cpp',
    'pyngraph/ops/op.cpp',
    'pyngraph/ops/one_hot.cpp',
    'pyngraph/ops/or.cpp',
    'pyngraph/ops/pad.cpp',
    'pyngraph/ops/parameter.cpp',
    'pyngraph/ops/passthrough.cpp',
    'pyngraph/ops/power.cpp',
    'pyngraph/ops/fused/prelu.cpp',
    'pyngraph/ops/quantize.cpp',
    'pyngraph/ops/quantized_convolution.cpp',
    'pyngraph/ops/quantized_dot.cpp',
    'pyngraph/ops/regmodule_pyngraph_op.cpp',
    'pyngraph/ops/relu.cpp',
    'pyngraph/ops/replace_slice.cpp',
    'pyngraph/ops/reshape.cpp',
    'pyngraph/ops/reverse.cpp',
    'pyngraph/ops/fused/rnn_cell.cpp',
    'pyngraph/ops/fused/scale_shift.cpp',
    'pyngraph/ops/select.cpp',
    'pyngraph/ops/fused/shuffle_channels.cpp',
    'pyngraph/ops/sign.cpp',
    'pyngraph/ops/sin.cpp',
    'pyngraph/ops/sinh.cpp',
    'pyngraph/ops/slice.cpp',
    'pyngraph/ops/fused/space_to_depth.cpp',
    'pyngraph/ops/sqrt.cpp',
    'pyngraph/ops/fused/squared_difference.cpp',
    'pyngraph/ops/fused/squeeze.cpp',
    'pyngraph/ops/subtract.cpp',
    'pyngraph/ops/sum.cpp',
    'pyngraph/ops/tan.cpp',
    'pyngraph/ops/tanh.cpp',
    'pyngraph/ops/topk.cpp',
    'pyngraph/ops/allreduce.cpp',
    'pyngraph/ops/get_output_element.cpp',
    'pyngraph/ops/min.cpp',
    'pyngraph/ops/batch_norm.cpp',
    'pyngraph/ops/softmax.cpp',
    'pyngraph/ops/result.cpp',
    'pyngraph/ops/fused/unsqueeze.cpp',
    'pyngraph/runtime/backend.cpp',
    'pyngraph/runtime/executable.cpp',
    'pyngraph/runtime/regmodule_pyngraph_runtime.cpp',
    'pyngraph/runtime/tensor.cpp',
    'pyngraph/passes/manager.cpp',
    'pyngraph/passes/regmodule_pyngraph_passes.cpp',
    'pyngraph/types/element_type.cpp',
    'pyngraph/types/regmodule_pyngraph_types.cpp',
]

packages = [
    'ngraph',
    'ngraph.utils',
    'ngraph.impl',
    'ngraph.impl.op',
    'ngraph.impl.op.util',
    'ngraph.impl.passes',
    'ngraph.impl.runtime',
]
# Every package lives in the directory that mirrors its dotted name under the
# root directory, e.g. 'ngraph.impl.op' -> '<root>/ngraph/impl/op'.
package_dir = {
    package: PYNGRAPH_ROOT_DIR + '/' + package.replace('.', '/')
    for package in packages
}

# Turn the relative C++ source paths into absolute ones.
sources = ['{}/{}'.format(PYNGRAPH_ROOT_DIR, source) for source in sources]

# Header search path: the bindings' root, nGraph's headers and pybind11's headers.
include_dirs = [PYNGRAPH_ROOT_DIR, NGRAPH_CPP_INCLUDE_DIR, PYBIND11_INCLUDE_DIR]
library_dirs = [NGRAPH_CPP_LIBRARY_DIR]
libraries = [NGRAPH_CPP_LIBRARY_NAME]

# Define NGRAPH_ONNX_IMPORT_ENABLE for the C++ sources only when requested.
extra_compile_args = (
    ['-DNGRAPH_ONNX_IMPORT_ENABLE']
    if NGRAPH_ONNX_IMPORT_ENABLE in ['TRUE', 'ON', True]
    else []
)

# Platform-specific linker flags are appended to this list by BuildExt.
extra_link_args = []

def _regular_files(directory):
    """Return the sorted paths of the regular files directly inside ``directory``.

    Subdirectories are skipped because ``data_files`` entries must name files;
    sorting makes the resulting list independent of directory listing order.
    """
    return sorted(
        directory + '/' + name
        for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


data_files = [
    ('lib', _regular_files(NGRAPH_CPP_LIBRARY_DIR)),
    ('licenses', _regular_files(NGRAPH_CPP_DIST_DIR + '/licenses')),
    ('', [NGRAPH_CPP_DIST_DIR + '/LICENSE']),
]

if NGRAPH_ONNX_IMPORT_ENABLE in ['TRUE', 'ON', True]:
    # Compile the ONNX importer binding and ship its Python package.
    sources = sources + [PYNGRAPH_ROOT_DIR + '/' + 'pyngraph/onnx_import/onnx_import.cpp']
    onnx_package = 'ngraph.impl.onnx_import'
    package_dir[onnx_package] = PYNGRAPH_ROOT_DIR + '/ngraph/impl/onnx_import'
    packages.append(onnx_package)

# The `_pyngraph` native extension, built from the C++ binding sources above.
pyngraph_extension = Extension(
    '_pyngraph',
    language='c++',
    sources=sources,
    include_dirs=include_dirs,
    library_dirs=library_dirs,
    libraries=libraries,
    define_macros=[('VERSION_INFO', __version__)],
    extra_compile_args=extra_compile_args,
    extra_link_args=extra_link_args,
)
ext_modules = [pyngraph_extension]

def add_platform_specific_link_args(link_args):
    """Extend ``link_args`` in place with linker flags for the current OS.

    Linux gets an ``$ORIGIN``-relative rpath plus ``-z noexecstack``,
    ``-z relro`` and ``-z now``; macOS gets an ``@loader_path``-relative rpath
    and ``-stdlib=libc++``. Other platforms are left unchanged.
    """
    platform = sys.platform
    if platform == 'darwin':
        extra = ['-Wl,-rpath,@loader_path/../..', '-stdlib=libc++']
    elif platform.startswith('linux'):
        extra = [
            '-Wl,-rpath,$ORIGIN/../..',
            '-z', 'noexecstack',
            '-z', 'relro',
            '-z', 'now',
        ]
    else:
        extra = []
    link_args.extend(extra)


class BuildExt(build_ext):
    """``build_ext`` command that adds compiler- and platform-specific options.

    Each extension gets the C++11 flag, optional hardening/visibility flags that
    the compiler is probed for with ``has_flag``, platform linker flags, and a
    fixed set of warning, optimisation and fortify options.
    """

    def _add_extra_compile_arg(self, flag, compile_args):
        """Append ``flag`` to ``compile_args`` if the compiler accepts it.

        Returns True when the flag was supported (and appended), else False.
        """
        if not has_flag(self.compiler, flag):
            return False
        compile_args.append(flag)
        return True

    def build_extensions(self):
        """Build extension providing extra compiler flags.

        Raises:
            RuntimeError: on win32, or from ``cpp_flag`` when the compiler
                does not accept ``-std=c++11``.
        """
        if sys.platform == 'win32':
            raise RuntimeError('Unsupported platform: win32!')

        # -Wstrict-prototypes is not a valid option for c++; drop it from the
        # compiler command when present (the attribute or entry may be missing).
        try:
            self.compiler.compiler_so.remove('-Wstrict-prototypes')
        except (AttributeError, ValueError):
            pass

        for extension in self.extensions:
            compile_args = extension.extra_compile_args
            compile_args.append(cpp_flag(self.compiler))

            # Prefer the strong stack protector; fall back to the plain one.
            if not self._add_extra_compile_arg('-fstack-protector-strong', compile_args):
                self._add_extra_compile_arg('-fstack-protector', compile_args)
            for optional_flag in ('-fvisibility=hidden', '-flto', '-fPIC'):
                self._add_extra_compile_arg(optional_flag, compile_args)

            add_platform_specific_link_args(extension.extra_link_args)

            compile_args.extend(['-Wformat', '-Wformat-security'])
            compile_args.extend(['-O2', '-D_FORTIFY_SOURCE=2'])
            if sys.platform == 'darwin':
                compile_args.append('-stdlib=libc++')

        build_ext.build_extensions(self)


with open(os.path.join(PYNGRAPH_ROOT_DIR, 'requirements.txt')) as requirements_file:
    requirements = requirements_file.read().splitlines()

# Requirement lines that start with "numpy" are also passed as setup_requires.
setup_requires = [line for line in requirements if line.strip().startswith('numpy')]

# Read the long description through a context manager so the file is closed,
# decoding it as UTF-8 (assumed encoding of README.md) independently of the
# build machine's locale.
with io.open(os.path.join(PYNGRAPH_ROOT_DIR, 'README.md'), encoding='utf-8') as readme_file:
    long_description = readme_file.read()

setup(
    name='ngraph-core',
    description="nGraph - Intel's graph compiler and runtime for Neural Networks",
    version=__version__,
    author='Intel Corporation',
    author_email='intelnervana@intel.com',
    url='https://github.com/NervanaSystems/ngraph/',
    # ``license`` takes the license name; the trove string goes in classifiers.
    license='Apache License 2.0',
    classifiers=['License :: OSI Approved :: Apache Software License'],
    long_description=long_description,
    long_description_content_type='text/markdown',
    ext_modules=ext_modules,
    package_dir=package_dir,
    packages=packages,
    cmdclass={'build_ext': BuildExt},
    data_files=data_files,
    setup_requires=setup_requires,
    install_requires=requirements,
    zip_safe=False,
    extras_require={
        'plaidml': ['plaidml>=0.6.3'],
    },
)