# ******************************************************************************
# Copyright 2017-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext
import sys
import setuptools
import os
import io
import re
# NOTE: distutils is deliberately imported after setuptools; newer setuptools
# versions substitute their own distutils copy, so keep this ordering.
import distutils.ccompiler

__version__ = os.environ.get('NGRAPH_VERSION', '0.0.0-dev')

# Directory containing this script; project-relative paths are built from it.
PYNGRAPH_ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
# Parent directory of the default 'ngraph_dist' location. When HOME is unset
# fall back to os.path.expanduser('~') so the value is never None (it is later
# passed to os.path.join, which rejects None).
NGRAPH_DEFAULT_INSTALL_DIR = os.environ.get('HOME', os.path.expanduser('~'))
# Raw build switches read from the environment (strings, or None when unset).
NGRAPH_ONNX_IMPORT_ENABLE = os.environ.get('NGRAPH_ONNX_IMPORT_ENABLE')
NGRAPH_PYTHON_DEBUG = os.environ.get('NGRAPH_PYTHON_DEBUG')

def find_ngraph_dist_dir():
    """Return location of compiled ngraph library home.

    The directory is taken from ``NGRAPH_CPP_BUILD_PATH`` when that variable is
    set and non-empty. Otherwise it is ``ngraph_dist`` under
    ``NGRAPH_DEFAULT_INSTALL_DIR``, or under the user's home directory (as
    resolved by ``os.path.expanduser``) when that constant is None.

    :return: path of the nGraph distribution directory.
    :raises SystemExit: with status 1 when ``include/ngraph`` is missing from
        the chosen directory.
    """
    ngraph_dist_dir = os.environ.get('NGRAPH_CPP_BUILD_PATH')
    if not ngraph_dist_dir:
        install_dir = NGRAPH_DEFAULT_INSTALL_DIR
        if install_dir is None:
            # NGRAPH_DEFAULT_INSTALL_DIR may be None (e.g. read from an unset
            # HOME); os.path.join(None, ...) would raise TypeError.
            install_dir = os.path.expanduser('~')
        ngraph_dist_dir = os.path.join(install_dir, 'ngraph_dist')

    if not os.path.exists(os.path.join(ngraph_dist_dir, 'include/ngraph')):
        print('Cannot find nGraph library in {} make sure that '
              'NGRAPH_CPP_BUILD_PATH is set correctly'.format(ngraph_dist_dir))
        sys.exit(1)

    print('nGraph library found in {}'.format(ngraph_dist_dir))
    return ngraph_dist_dir


def find_pybind_headers_dir():
    """Locate the pybind11 source tree whose ``include`` folder holds the headers.

    ``PYBIND_HEADERS_PATH`` wins when it is set to a non-empty value. Otherwise
    the ``pybind11`` directory next to this script is used.

    :return: the pybind11 root directory.
    :raises SystemExit: with status 1 if ``include/pybind11`` is absent there.
    """
    headers_root = (os.environ.get('PYBIND_HEADERS_PATH')
                    or os.path.join(PYNGRAPH_ROOT_DIR, 'pybind11'))

    marker = os.path.join(headers_root, 'include/pybind11')
    if not os.path.exists(marker):
        print('Cannot find pybind11 library in {} make sure that '
              'PYBIND_HEADERS_PATH is set correctly'.format(headers_root))
        sys.exit(1)

    print('pybind11 library found in {}'.format(headers_root))
    return headers_root


NGRAPH_CPP_DIST_DIR = find_ngraph_dist_dir()
PYBIND11_INCLUDE_DIR = find_pybind_headers_dir() + '/include'
NGRAPH_CPP_INCLUDE_DIR = NGRAPH_CPP_DIST_DIR + '/include'

# Use the first existing library folder of the distribution; 'lib' takes
# precedence over 'lib64'.
NGRAPH_CPP_LIBRARY_DIR = next(
    (
        candidate
        for candidate in (NGRAPH_CPP_DIST_DIR + '/lib', NGRAPH_CPP_DIST_DIR + '/lib64')
        if os.path.exists(candidate)
    ),
    None,
)
if NGRAPH_CPP_LIBRARY_DIR is None:
    print('Cannot find library directory in {}, make sure that nGraph is installed '
          'correctly'.format(NGRAPH_CPP_DIST_DIR))
    sys.exit(1)
NGRAPH_CPP_LIBRARY_NAME = 'ngraph'
# For some platforms OpenVINO adds a 'd' suffix to library names in the debug
# configuration, so link against 'ngraphd' whenever a file name in the library
# directory contains that string.
if any('ngraphd' in file_name for file_name in os.listdir(NGRAPH_CPP_LIBRARY_DIR)):
    NGRAPH_CPP_LIBRARY_NAME = 'ngraphd'
def parallelCCompile(
    self,
    sources,
    output_dir=None,
    macros=None,
    include_dirs=None,
    debug=0,
    extra_preargs=None,
    extra_postargs=None,
    depends=None,
):
    """Build sources in parallel.

    Monkey-patch replacement for ``distutils.ccompiler.CCompiler.compile`` that
    compiles each object on a worker thread of a ``ThreadPool``.

    Reference link:
    http://stackoverflow.com/questions/11013851/speeding-up-build-process-with-distutils

    :param self: the compiler instance (this function is bound as a method).
    :param sources: source files to compile.
    :param output_dir, macros, include_dirs, debug, extra_preargs,
        extra_postargs, depends: forwarded to the compiler's ``_setup_compile``
        and ``_get_cc_args`` helpers exactly as ``CCompiler.compile`` does.
    :return: the list of object file names produced by ``_setup_compile``.
    """
    # those lines are copied from distutils.ccompiler.CCompiler directly
    macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
        output_dir, macros, include_dirs, sources, depends, extra_postargs)
    cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)
    # parallel code
    import multiprocessing.pool

    def _single_compile(obj):
        try:
            src, ext = build[obj]
        except KeyError:
            # objects missing from the build map are not compiled
            return
        self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)

    pool = multiprocessing.pool.ThreadPool()
    try:
        # convert to list, imap is evaluated on-demand
        list(pool.imap(_single_compile, objects))
    finally:
        # Release the worker threads even when a compilation fails; without
        # this the pool's threads outlive the call.
        pool.close()
        pool.join()
    return objects


# Install parallelCCompile as compile() on the CCompiler base class; compiler
# subclasses that do not override compile() inherit the threaded version.
setattr(distutils.ccompiler.CCompiler, 'compile', parallelCCompile)


def has_flag(compiler, flagname):
    """Check whether a flag is supported by the specified compiler.

    A trivial C++ program is compiled with ``flagname`` appended to the
    compiler arguments. The flag counts as supported when compilation succeeds.
    cf http://bugs.python.org/issue26689 (proposal for a built-in ``has_flag``).

    :param compiler: distutils/setuptools compiler instance to probe.
    :param flagname: the command-line flag to test, e.g. ``'-std=c++11'``.
    :return: True if compiling with the flag succeeded, False if the compiler
        raised ``CompileError``.
    """
    import shutil
    import tempfile

    # Keep both the probe source and the object file in a private scratch
    # directory. Without output_dir the object would be written relative to
    # the current working directory.
    tmp_dir = tempfile.mkdtemp()
    try:
        src_path = os.path.join(tmp_dir, 'flag_check.cpp')
        # Closing the file guarantees the program text is on disk before the
        # compiler reads it (an unflushed temp file may still be empty).
        with open(src_path, 'w') as src:
            src.write('int main (int argc, char **argv) { return 0; }')
        try:
            compiler.compile([src_path], output_dir=tmp_dir, extra_postargs=[flagname])
        except setuptools.distutils.errors.CompileError:
            return False
        return True
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)


def cpp_flag(compiler):
    """Return the C++11 standard flag, failing if the compiler rejects it.

    :param compiler: compiler instance that is probed via ``has_flag``.
    :return: the string ``'-std=c++11'``.
    :raises RuntimeError: when the compiler does not accept ``-std=c++11``.
    """
    std_flag = '-std=c++11'
    if not has_flag(compiler, std_flag):
        raise RuntimeError('Unsupported compiler -- C++11 support is needed!')
    return std_flag


# C++ sources of the _pyngraph extension, relative to PYNGRAPH_ROOT_DIR.
sources = [
    'pyngraph/function.cpp',
    'pyngraph/serializer.cpp',
    'pyngraph/node.cpp',
    'pyngraph/shape.cpp',
    'pyngraph/strides.cpp',
    'pyngraph/coordinate_diff.cpp',
    'pyngraph/axis_set.cpp',
    'pyngraph/axis_vector.cpp',
    'pyngraph/coordinate.cpp',
    'pyngraph/pyngraph.cpp',
    'pyngraph/util.cpp',
    'pyngraph/ops/util/arithmetic_reduction.cpp',
    'pyngraph/ops/util/binary_elementwise_comparison.cpp',
    'pyngraph/ops/util/op_annotations.cpp',
    'pyngraph/ops/util/binary_elementwise_arithmetic.cpp',
    'pyngraph/ops/util/binary_elementwise_logical.cpp',
    'pyngraph/ops/util/regmodule_pyngraph_op_util.cpp',
    'pyngraph/ops/util/unary_elementwise_arithmetic.cpp',
    'pyngraph/ops/util/index_reduction.cpp',
    'pyngraph/ops/abs.cpp',
    'pyngraph/ops/acos.cpp',
    'pyngraph/ops/add.cpp',
    'pyngraph/ops/and.cpp',
    'pyngraph/ops/argmax.cpp',
    'pyngraph/ops/argmin.cpp',
    'pyngraph/ops/asin.cpp',
    'pyngraph/ops/atan.cpp',
    'pyngraph/ops/avg_pool.cpp',
    'pyngraph/ops/broadcast.cpp',
    'pyngraph/ops/broadcast_distributed.cpp',
    'pyngraph/ops/fused/clamp.cpp',
    'pyngraph/ops/concat.cpp',
    'pyngraph/ops/constant.cpp',
    'pyngraph/ops/convert.cpp',
    'pyngraph/ops/convolution.cpp',
    'pyngraph/ops/cos.cpp',
    'pyngraph/ops/cosh.cpp',
    'pyngraph/ops/ceiling.cpp',
    'pyngraph/ops/fused/depth_to_space.cpp',
    'pyngraph/ops/dequantize.cpp',
    'pyngraph/ops/divide.cpp',
    'pyngraph/ops/dot.cpp',
    'pyngraph/ops/fused/elu.cpp',
    'pyngraph/ops/equal.cpp',
    'pyngraph/ops/exp.cpp',
    'pyngraph/ops/fused/fake_quantize.cpp',
    'pyngraph/ops/floor.cpp',
    'pyngraph/ops/fused/gelu.cpp',
    'pyngraph/ops/fused/gemm.cpp',
    'pyngraph/ops/greater.cpp',
    'pyngraph/ops/greater_eq.cpp',
    'pyngraph/ops/fused/grn.cpp',
    'pyngraph/ops/fused/group_conv.cpp',
    'pyngraph/ops/fused/hard_sigmoid.cpp',
    'pyngraph/ops/less.cpp',
    'pyngraph/ops/less_eq.cpp',
    'pyngraph/ops/log.cpp',
    'pyngraph/ops/lrn.cpp',
    'pyngraph/ops/maximum.cpp',
    'pyngraph/ops/max.cpp',
    'pyngraph/ops/product.cpp',
    'pyngraph/ops/max_pool.cpp',
    'pyngraph/ops/minimum.cpp',
    'pyngraph/ops/multiply.cpp',
    'pyngraph/ops/fused/mvn.cpp',
    'pyngraph/ops/negative.cpp',
    'pyngraph/ops/not.cpp',
    'pyngraph/ops/not_equal.cpp',
    'pyngraph/ops/op.cpp',
    'pyngraph/ops/one_hot.cpp',
    'pyngraph/ops/or.cpp',
    'pyngraph/ops/pad.cpp',
    'pyngraph/ops/parameter.cpp',
    'pyngraph/ops/passthrough.cpp',
    'pyngraph/ops/power.cpp',
    'pyngraph/ops/fused/prelu.cpp',
    'pyngraph/ops/quantize.cpp',
    'pyngraph/ops/quantized_convolution.cpp',
    'pyngraph/ops/quantized_dot.cpp',
    'pyngraph/ops/regmodule_pyngraph_op.cpp',
    'pyngraph/ops/relu.cpp',
    'pyngraph/ops/replace_slice.cpp',
    'pyngraph/ops/reshape.cpp',
    'pyngraph/ops/reverse.cpp',
    'pyngraph/ops/fused/rnn_cell.cpp',
    'pyngraph/ops/fused/scale_shift.cpp',
    'pyngraph/ops/select.cpp',
    'pyngraph/ops/fused/shuffle_channels.cpp',
    'pyngraph/ops/sign.cpp',
    'pyngraph/ops/sin.cpp',
    'pyngraph/ops/sinh.cpp',
    'pyngraph/ops/slice.cpp',
    'pyngraph/ops/fused/space_to_depth.cpp',
    'pyngraph/ops/sqrt.cpp',
    'pyngraph/ops/fused/squared_difference.cpp',
    'pyngraph/ops/fused/squeeze.cpp',
    'pyngraph/ops/subtract.cpp',
    'pyngraph/ops/sum.cpp',
    'pyngraph/ops/tan.cpp',
    'pyngraph/ops/tanh.cpp',
    'pyngraph/ops/topk.cpp',
    'pyngraph/ops/allreduce.cpp',
    'pyngraph/ops/get_output_element.cpp',
    'pyngraph/ops/min.cpp',
    'pyngraph/ops/batch_norm.cpp',
    'pyngraph/ops/softmax.cpp',
    'pyngraph/ops/result.cpp',
    'pyngraph/ops/fused/unsqueeze.cpp',
    'pyngraph/runtime/backend.cpp',
    'pyngraph/runtime/executable.cpp',
    'pyngraph/runtime/regmodule_pyngraph_runtime.cpp',
    'pyngraph/runtime/tensor.cpp',
    'pyngraph/passes/manager.cpp',
    'pyngraph/passes/regmodule_pyngraph_passes.cpp',
    'pyngraph/types/element_type.cpp',
    'pyngraph/types/regmodule_pyngraph_types.cpp',
]

packages = [
    'ngraph',
    'ngraph.utils',
    'ngraph.impl',
    'ngraph.impl.op',
    'ngraph.impl.op.util',
    'ngraph.impl.passes',
    'ngraph.impl.runtime',
]
# Every package lives in the directory that mirrors its dotted name under the
# project root, e.g. 'ngraph.impl.op' -> PYNGRAPH_ROOT_DIR + '/ngraph/impl/op'.
# Deriving the mapping from `packages` keeps the two from drifting apart.
package_dir = {
    package: PYNGRAPH_ROOT_DIR + '/' + package.replace('.', '/')
    for package in packages
}
# Turn the project-relative source paths into absolute ones.
sources = [PYNGRAPH_ROOT_DIR + '/' + source for source in sources]

include_dirs = [PYNGRAPH_ROOT_DIR, NGRAPH_CPP_INCLUDE_DIR, PYBIND11_INCLUDE_DIR]
library_dirs = [NGRAPH_CPP_LIBRARY_DIR]
libraries = [NGRAPH_CPP_LIBRARY_NAME]

extra_compile_args = []
# NGRAPH_ONNX_IMPORT_ENABLE comes from the environment (a string or None), so
# in practice only the exact values 'TRUE' and 'ON' enable the ONNX importer.
if NGRAPH_ONNX_IMPORT_ENABLE in ['TRUE', 'ON', True]:
    extra_compile_args.append('-DNGRAPH_ONNX_IMPORT_ENABLE')

extra_link_args = []
301 302 303 304
data_files = [
    (
        'lib',
        [
305
            NGRAPH_CPP_LIBRARY_DIR + '/' + library
306 307
            for library in os.listdir(NGRAPH_CPP_LIBRARY_DIR)
        ],
308 309 310 311
    ),
    (
        'licenses',
        [
312 313
            NGRAPH_CPP_DIST_DIR + '/licenses/' + license
            for license in os.listdir(NGRAPH_CPP_DIST_DIR + '/licenses')
314 315 316 317
        ],
    ),
    (
        '',
318
        [NGRAPH_CPP_DIST_DIR + '/LICENSE'],
319
    ),
320 321
]

if NGRAPH_ONNX_IMPORT_ENABLE in ['TRUE', 'ON', True]:
    # Optional ONNX importer: compile its binding source and ship the matching
    # Python package only when it was enabled in the environment.
    onnx_sources = [
        PYNGRAPH_ROOT_DIR + '/' + relative_path
        for relative_path in ('pyngraph/onnx_import/onnx_import.cpp',)
    ]
    sources = sources + onnx_sources

    package_dir['ngraph.impl.onnx_import'] = PYNGRAPH_ROOT_DIR + '/ngraph/impl/onnx_import'
    packages.append('ngraph.impl.onnx_import')
# The compiled '_pyngraph' extension, built from the sources and compiler /
# linker settings collected above.
ext_modules = [
    Extension(
        '_pyngraph',
        sources=sources,
        include_dirs=include_dirs,
        define_macros=[('VERSION_INFO', __version__)],
        library_dirs=library_dirs,
        libraries=libraries,
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args,
        language='c++',
    ),
]
def add_platform_specific_link_args(link_args):
    """Extend ``link_args`` in place with linker flags for the running OS.

    Linux gets an ``$ORIGIN``-relative rpath plus the ``-z noexecstack``,
    ``-z relro`` and ``-z now`` options. macOS gets an ``@loader_path``-relative
    rpath and ``-stdlib=libc++``. Any other platform is left untouched.

    :param link_args: mutable list of linker arguments, modified in place.
    :return: None.
    """
    if sys.platform == 'darwin':
        link_args.extend(['-Wl,-rpath,@loader_path/../..', '-stdlib=libc++'])
        return
    if not sys.platform.startswith('linux'):
        return
    link_args.append('-Wl,-rpath,$ORIGIN/../..')
    for keyword in ('noexecstack', 'relro', 'now'):
        link_args.extend(['-z', keyword])
class BuildExt(build_ext):
    """``build_ext`` command that adds compiler-specific options.

    Each extension receives the C++11 standard flag, the hardening and
    visibility flags the compiler accepts, platform linker flags, and debug or
    release optimization flags selected by NGRAPH_PYTHON_DEBUG.
    """

    def _add_extra_compile_arg(self, flag, compile_args):
        """Append ``flag`` to ``compile_args`` if the compiler accepts it.

        :return: True when the flag was appended, False otherwise.
        """
        if not has_flag(self.compiler, flag):
            return False
        compile_args += [flag]
        return True

    def add_debug_or_release_flags(self):
        """Return optimization flags for the selected build type.

        NGRAPH_PYTHON_DEBUG equal to 'TRUE' or 'ON' selects ``-O0 -g``.
        Anything else selects ``-O2 -D_FORTIFY_SOURCE=2``.
        """
        debug_build = NGRAPH_PYTHON_DEBUG in ['TRUE', 'ON', True]
        return ['-O0', '-g'] if debug_build else ['-O2', '-D_FORTIFY_SOURCE=2']

    def build_extensions(self):
        """Extend every extension's compile/link flags, then build them all.

        :raises RuntimeError: on Windows, which is not supported.
        """
        if sys.platform == 'win32':
            raise RuntimeError('Unsupported platform: win32!')

        # -Wstrict-prototypes is meaningful for C only; drop it for C++ sources.
        try:
            self.compiler.compiler_so.remove('-Wstrict-prototypes')
        except (AttributeError, ValueError):
            pass

        for extension in self.extensions:
            args = extension.extra_compile_args
            args += [cpp_flag(self.compiler)]

            # Prefer the strong stack protector; fall back to the basic one.
            if not self._add_extra_compile_arg('-fstack-protector-strong', args):
                self._add_extra_compile_arg('-fstack-protector', args)

            for optional_flag in ('-fvisibility=hidden', '-flto', '-fPIC'):
                self._add_extra_compile_arg(optional_flag, args)
            add_platform_specific_link_args(extension.extra_link_args)

            args += ['-Wformat', '-Wformat-security']
            args += self.add_debug_or_release_flags()

            if sys.platform == 'darwin':
                args += ['-stdlib=libc++']

        build_ext.build_extensions(self)
with open(os.path.join(PYNGRAPH_ROOT_DIR, 'requirements.txt')) as req:
    # One requirement per line; the full list is passed as install_requires.
    requirements = req.read().splitlines()

# Entries starting with 'numpy' (ignoring surrounding whitespace) are also
# passed to setup() as setup_requires.
setup_requires = list(
    filter(lambda item: item.strip().startswith('numpy'), requirements)
)

# Read the long description through a context manager so the handle is closed
# right away, and decode it as UTF-8 rather than with the locale's default
# encoding (which can fail on non-ASCII text under a C/POSIX locale).
with io.open(os.path.join(PYNGRAPH_ROOT_DIR, 'README.md'), encoding='utf-8') as readme:
    long_description = readme.read()

setup(
    name='ngraph-core',
    description="nGraph - Intel's graph compiler and runtime for Neural Networks",
    version=__version__,
    author='Intel Corporation',
    author_email='intelnervana@intel.com',
    url='https://github.com/NervanaSystems/ngraph/',
    # NOTE(review): this value is a trove classifier string rather than a
    # license name; consider license='Apache License 2.0' plus a classifiers
    # entry.
    license='License :: OSI Approved :: Apache Software License',
    long_description=long_description,
    long_description_content_type='text/markdown',
    ext_modules=ext_modules,
    package_dir=package_dir,
    packages=packages,
    cmdclass={'build_ext': BuildExt},
    data_files=data_files,
    setup_requires=setup_requires,
    install_requires=requirements,
    zip_safe=False,
    extras_require={
        'plaidml': ['plaidml>=0.6.3'],
    },
)