Commit 2942a32a authored by Robert Kimball's avatar Robert Kimball

initial commit

parents
BasedOnStyle: LLVM
IndentWidth: 4
UseTab: Never
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -4
AlignConsecutiveDeclarations: true
AlignConsecutiveAssignments: true
AlignTrailingComments: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: Inline
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BreakBeforeBraces: Allman
BreakConstructorInitializersBeforeComma: true
ColumnLimit: 100
CommentPragmas: '.*'
IndentCaseLabels: false
IndentWrappedFunctionNames: true
KeepEmptyLinesAtTheStartOfBlocks: false
NamespaceIndentation: All
PointerAlignment: Left
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesInAngles: false
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
SortIncludes: false
ReflowComments: true
# Compiled Object files
*.slo
*.lo
*.o
*.obj
*.pyc
# Precompiled Headers
*.gch
*.pch
# Compiled Dynamic libraries
*.so
*.dylib
*.dll
# Fortran module files
*.mod
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Executables
*.exe
*.out
*.app
# QT Creator
*.creator
*.config
*.files
*.includes
*.creator.user*
*.autosave
.DS_Store
@eaDir/
.d/
bin/
*.log
output/
*.png
*.jpg
*.mp2
*.mpg
*.cpio
*.wav
doc/source/generated
.cache/
nervana_aeon.egg-info/
# vim
*.swp
*.swo
# setup.py intermediate files
build/
# makeenv and test intermediate files
tmp/
# Apple
*.AppleDouble
config_args.txt
.nfs*
venv/
.vscode/
# Copyright 2017 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required (VERSION 2.8.11)
# cmake_policy(SET CMP0037 OLD)
# cmake_policy(SET CMP0042 OLD)

# Set this flag before project definition to avoid using other compiler by gtest
set(CMAKE_CXX_COMPILER "clang++")

project (ngraph)
set(NGRAPH_VERSION 1.0.0)

# Refuse in-source builds so generated files never pollute the source tree.
set(CMAKE_DISABLE_SOURCE_CHANGES ON)
set(CMAKE_DISABLE_IN_SOURCE_BUILD ON)

# set directory where the custom finders live
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/")

# Base optimization level and language standard.
set(CMAKE_CXX_FLAGS "-O3")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")

# Promote the most dangerous diagnostics to hard errors.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror=return-type")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror=inconsistent-missing-override")

# Enable every clang warning, then whitelist the ones we deliberately ignore.
# whitelist errors here
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Weverything")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-c++98-compat")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-c++98-compat-pedantic")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-weak-vtables")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-global-constructors")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-switch-enum")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-gnu-zero-variadic-macro-arguments")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-undef")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-exit-time-destructors")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-prototypes")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-disabled-macro-expansion")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-pedantic")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-documentation")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-covered-switch-default")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-old-style-cast")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unknown-warning-option")
# should remove these
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-conversion")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-conversion")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-equal")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-duplicate-enum") # from numpy
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-used-but-marked-unused") # from sox
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-c++11-compat-deprecated-writable-strings")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-double-promotion")

# Optional Python/numpy bindings: compiled in only when both are found.
set(Python_ADDITIONAL_VERSIONS 3.6 3.5 3.4)
find_package(PythonLibs)
find_package(PythonInterp)
if (PYTHONLIBS_FOUND)
    find_package(NumPy)
    if(NUMPY_FOUND)
        set(PYTHON_FOUND true)
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DPYTHON_FOUND")
        include_directories(SYSTEM ${PYTHON_INCLUDE_DIR})
        include_directories(SYSTEM ${NUMPY_INCLUDE_DIRS})
    else()
        message("Numpy not found, Python interface not included")
    endif()
else()
    message("Python not found, Python interface not included")
    set(PYTHON_LIBRARIES "")
endif()

set(TEST_DATA_DIR ${CMAKE_CURRENT_SOURCE_DIR}/test/test_data)

add_subdirectory(src)
include_directories(src src/transformers)
add_subdirectory(test)
# CMake find module for the breathe doxygen-to-sphinx bridge.
# On success BREATHE_FOUND and BREATHE_EXECUTABLE are defined.
find_program(BREATHE_EXECUTABLE
NAMES breathe-apidoc
DOC "Path to breathe executable")
# Handle REQUIRED and QUIET arguments
# this will also set BREATHE_FOUND to true if BREATHE_EXECUTABLE exists
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Breathe
"Failed to locate breathe executable"
BREATHE_EXECUTABLE)
# Locate a Python interpreter unless the including project already did.
if(NOT PYTHON_EXECUTABLE)
    if(NumPy_FIND_QUIETLY)
        find_package(PythonInterp QUIET)
    else()
        find_package(PythonInterp)
    endif()
endif()

if(PYTHON_EXECUTABLE)
    # Ask numpy itself for its include directory; RESULT_VARIABLE receives the
    # interpreter's exit status (0 on success).
    execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
        "import numpy as n; print(n.get_include());"
        RESULT_VARIABLE NUMPY_RESULT
        OUTPUT_VARIABLE NUMPY_INCLUDE_DIRS
        OUTPUT_STRIP_TRAILING_WHITESPACE)
    # Use EQUAL, not MATCHES: "MATCHES 0" is a regex test that also accepts any
    # non-zero exit code containing the digit 0 (e.g. 10).
    if(NUMPY_RESULT EQUAL 0)
        set(NUMPY_FOUND true)
        set(NUMPY_LIBRARIES "")
        set(NUMPY_DEFINITIONS "")
    endif()
endif()
# CMake find_package() Module for Sphinx documentation generator
# http://sphinx-doc.org/
#
# Example usage:
#
# find_package(Sphinx)
#
# If successful the following variables will be defined
# SPHINX_FOUND
# SPHINX_EXECUTABLE
# sphinx-build2 is the python2 entry point on some distros.
find_program(SPHINX_EXECUTABLE
NAMES sphinx-build sphinx-build2
DOC "Path to sphinx-build executable")
# Handle REQUIRED and QUIET arguments
# this will also set SPHINX_FOUND to true if SPHINX_EXECUTABLE exists
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Sphinx
"Failed to locate sphinx-build executable"
SPHINX_EXECUTABLE)
# Provide options for controlling different types of output
option(SPHINX_OUTPUT_HTML "Output standalone HTML files" ON)
option(SPHINX_OUTPUT_MAN "Output man pages" ON)
option(SPHINX_WARNINGS_AS_ERRORS "When building documentation treat warnings as errors" ON)
\ No newline at end of file
# Sources of the core ngraph shared library.
set (SRC
    util.cpp
    transformers/exop.cpp
    transformers/op_graph.cpp
    transformers/axes.cpp
    transformers/mock_transformer.cpp
    element_type.cpp
    names.cpp
    transformers/ndarray.cpp
    strides.cpp
    tree.cpp
)
# file(GLOB DEPLOY_HEADERS_ABS *.hpp)
# foreach(DEPLOY_HEADER ${DEPLOY_HEADERS_ABS})
# get_filename_component(FNAME ${DEPLOY_HEADER} NAME)
# set(DEPLOY_HEADERS ${DEPLOY_HEADERS} ${FNAME})
# endforeach()
# if (PYTHON_FOUND)
# set(SRC ${SRC} api.cpp)
# endif(PYTHON_FOUND)
# set(SETUP_PY_IN "${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in")
# set(SETUP_PY "${CMAKE_CURRENT_BINARY_DIR}/../setup.py")
# configure_file(${SETUP_PY_IN} ${SETUP_PY})
add_library(ngraph SHARED ${SRC})
# Attach the include path to the target instead of using directory-scoped
# include_directories(.), so consumers linking ngraph inherit it.
target_include_directories(ngraph PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
# install(TARGETS aeon DESTINATION lib)
# install(FILES ${DEPLOY_HEADERS} DESTINATION include/aeon)
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <cmath>
#include "element_type.hpp"
// Predefined element types: (bitwidth, is_float, is_signed, C type name).
// The C type name must match the signedness of the type: the unsigned
// variants previously reused the signed names ("int8_t" for uint8_t, etc.),
// which made c_type_string() return the wrong C type for unsigned elements.
const ElementType element_type_float = ElementType(32, true, true, "float");
const ElementType element_type_int8_t = ElementType(8, false, true, "int8_t");
const ElementType element_type_int32_t = ElementType(32, false, true, "int32_t");
const ElementType element_type_int64_t = ElementType(64, false, true, "int64_t");
const ElementType element_type_uint8_t = ElementType(8, false, false, "uint8_t");
const ElementType element_type_uint32_t = ElementType(32, false, false, "uint32_t");
const ElementType element_type_uint64_t = ElementType(64, false, false, "uint64_t");

// Definition of the static registry declared in element_type.hpp.
std::map<std::string, ElementType> ElementType::m_element_list;
// Construct an element type descriptor.
// bitwidth  - size of the type in bits; must be a whole number of bytes.
// is_float  - true for floating-point types.
// is_signed - true for signed types.
// cname     - the C type name used when emitting code (e.g. "float").
ElementType::ElementType(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname)
: m_bitwidth{bitwidth}
, m_is_float{is_float}
, m_is_signed{is_signed}
, m_cname{cname}
{
assert(m_bitwidth % 8 == 0);
}
// The C type name this element type maps to.
const std::string& ElementType::c_type_string() const
{
return m_cname;
}
// Equality compares bitwidth/float-ness/signedness only; m_cname is excluded.
// NOTE(review): hash() (see element_type.hpp) hashes only m_cname, so two
// objects that compare equal here may hash differently — confirm intended.
bool ElementType::operator==(const ElementType& other) const
{
return m_bitwidth == other.m_bitwidth && m_is_float == other.m_is_float &&
m_is_signed == other.m_is_signed;
}
// Size in bytes of one element, rounding any partial byte up.
// Integer ceiling-divide replaces the previous float std::ceil round-trip;
// it is exact for every bitwidth and needs no <cmath>.
size_t ElementType::size() const
{
    return (m_bitwidth + 7) / 8;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
//================================================================================================
// ElementType
//================================================================================================
#pragma once
#include <string>
#include <map>
// Describes a scalar element type of a tensor: bit width, float/integer,
// signedness, and the C type name used for code emission.
class ElementType
{
public:
// bitwidth must be a whole number of bytes (asserted in the constructor).
ElementType(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname);
// The C type name this element type maps to (e.g. "float", "int32_t").
const std::string& c_type_string() const;
// Size of one element in bytes.
size_t size() const;
// Hash based solely on the C type name.
// NOTE(review): operator== (element_type.cpp) ignores m_cname, so objects
// that compare equal may hash differently — confirm which is intended.
size_t hash() const
{
std::hash<std::string> h;
return h(m_cname);
}
bool operator==(const ElementType& other) const;
private:
// Registry of named element types (currently unused in visible code).
static std::map<std::string, ElementType> m_element_list;
size_t m_bitwidth;
bool m_is_float;
bool m_is_signed;
const std::string m_cname;
};
extern const ElementType element_type_float;
extern const ElementType element_type_int8_t;
extern const ElementType element_type_int32_t;
extern const ElementType element_type_int64_t;
extern const ElementType element_type_uint8_t;
extern const ElementType element_type_uint32_t;
extern const ElementType element_type_uint64_t;
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include "names.hpp"
// Shared suffix counter used when generating unique names.
size_t NameableValue::__counter = 0;
// Registry of every name handed out so far.
std::map<std::string, NameableValue> NameableValue::__all_names;

// Construct a nameable value. The graph label is "<label>[<name>]", where the
// label defaults to the name itself when no graph_label_type is supplied.
NameableValue::NameableValue(const std::string& name,
                             const std::string& graph_label_type,
                             const std::string& doc_string)
    : m_name{name}
    , m_doc_string{doc_string}
{
    const std::string& label = graph_label_type.empty() ? m_name : graph_label_type;
    m_graph_label = label + "[" + m_name + "]";
}
// The label used for drawings of the graph ("<label>[<name>]").
const std::string& NameableValue::graph_label()
{
return m_graph_label;
}
// The object's current name.
const std::string& NameableValue::name()
{
return m_name;
}
void NameableValue::name(const std::string& name)
{
// if name == type(self).__name__ or name in NameableValue.__all_names:
// while True:
// c_name = "{}_{}".format(name, type(self).__counter)
// if c_name not in NameableValue.__all_names:
// name = c_name
// break
// type(self).__counter += 1
// NameableValue.__all_names[name] = self
// self.__name = name
}
const std::string& NameableValue::short_name()
{
// sn = self.name.split('_')[0]
// if sn.find('.') != -1:
// sn = sn.split('.')[1]
// return sn
static const std::string x = "unimplemented";
return x;
}
// Fluent setter: overwrite the name directly (no uniquing, no registry
// update) and return *this for chaining.
NameableValue& NameableValue::named(const std::string& name)
{
m_name = name;
return *this;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <string>
#include <map>
//================================================================================================
// NameableValue
// An Axis labels a dimension of a tensor. The op-graph uses
// the identity of Axis objects to pair and specify dimensions in
// symbolic expressions. This system has several advantages over
// using the length and position of the axis as in other frameworks:
//
// 1) Convenience. The dimensions of tensors, which may be nested
// deep in a computation graph, can be specified without having to
// calculate their lengths.
//
// 2) Safety. Axis labels are analogous to types in general-purpose
// programming languages, allowing objects to interact only when
// they are permitted to do so in advance. In symbolic computation,
// this prevents interference between axes that happen to have the
// same lengths but are logically distinct, e.g. if the number of
// training examples and the number of input features are both 50.
//
// TODO: Please add to the list...
//
// Arguments:
// length: The length of the axis.
// batch: Whether the axis is a batch axis.
// recurrent: Whether the axis is a recurrent axis.
//================================================================================================
class NameableValue
{
public:
//!-----------------------------------------------------------------------------------
//! NameableValue
//! An object that can be named.
//!
//! Arguments:
//! graph_label_type: A label that should be used when drawing the graph. Defaults to
//! the class name.
//! name (str): The name of the object.
//! **kwargs: Parameters for related classes.
//!
//! Attributes:
//! graph_label_type: A label that should be used when drawing the graph.
//! id: Unique id for this object.
//!-----------------------------------------------------------------------------------
NameableValue(const std::string& name,
const std::string& graph_label_type = "",
const std::string& doc_string = "");
//!-----------------------------------------------------------------------------------
//! graph_label
//! The label used for drawings of the graph.
//!-----------------------------------------------------------------------------------
const std::string& graph_label();
//!-----------------------------------------------------------------------------------
//! name
//! Sets the object name to a unique name based on name.
//!
//! Arguments:
//! name: Prefix for the name
//!-----------------------------------------------------------------------------------
const std::string& name();
//!-----------------------------------------------------------------------------------
//! name
//!-----------------------------------------------------------------------------------
void name(const std::string& name);
//!-----------------------------------------------------------------------------------
//! short_name
//!-----------------------------------------------------------------------------------
const std::string& short_name();
//!-----------------------------------------------------------------------------------
//! named
//!-----------------------------------------------------------------------------------
NameableValue& named(const std::string& name);
static size_t __counter;
static std::map<std::string, NameableValue> __all_names;
std::string m_name;
std::string m_graph_label;
std::string m_short_name;
std::string m_doc_string;
};
#include <iostream>
#include <algorithm>
#include "strides.hpp"
#include "util.hpp"
using namespace std;
//================================================================================================
//
//================================================================================================
// Default: empty shape tree, float element type.
ngraph::tensor_size::tensor_size()
: m_tree{}
, m_element_type{element_type_float}
{
}
// Scalar shape of a single dimension of extent s.
ngraph::tensor_size::tensor_size(size_t s, ElementType et)
: m_tree{s}
, m_element_type{et}
{
}
// Nested shape built from a (possibly nested) brace list of sub-trees.
ngraph::tensor_size::tensor_size(const std::initializer_list<scalar_tree>& list, ElementType et)
: m_tree{list}
, m_element_type{et}
{
}
// Private: flat shape from a vector of extents (used by sizes()/strides()).
// NOTE(review): m_tree{list} is braced init; with tree's
// initializer_list<tree> constructor available, {list} may select that
// overload and wrap the vector in an extra nesting level rather than calling
// tree(const std::vector<T>&) directly — confirm which is intended.
ngraph::tensor_size::tensor_size(const std::vector<size_t>& list, const ElementType& et)
: m_tree{list}
, m_element_type{et}
{
}
// Compute per-dimension strides in bytes, preserving the nesting of the shape
// tree. The tree is flattened left-to-right; the innermost (last) slot gets
// the element size and each earlier slot is the product of the next slot's
// stride and extent.
ngraph::tensor_stride ngraph::tensor_size::full_strides() const
{
    tensor_stride result{*this};
    vector<size_t*> value_pointer_list;
    vector<size_t> size_list;
    scalar_tree::traverse_tree(result.m_tree, [&](size_t* value) {
        value_pointer_list.push_back(value);
        size_list.push_back(*value);
    });
    // Guard the degenerate empty tree: the original code indexed
    // value_pointer_list[size()-1] unconditionally, which is undefined when
    // the traversal visited no values.
    if (value_pointer_list.empty())
    {
        return result;
    }
    // Walk from the innermost dimension outward.
    int index = static_cast<int>(value_pointer_list.size()) - 1;
    *value_pointer_list[index] = result.m_element_type.size();
    for (index--; index >= 0; index--)
    {
        *value_pointer_list[index] = *value_pointer_list[index + 1] * size_list[index + 1];
    }
    return result;
}
// Strides with each nested group collapsed to its minimum (see
// tensor_stride::strides).
ngraph::tensor_stride ngraph::tensor_size::strides() const
{
return full_strides().strides();
}
// Collapse each top-level group of the shape tree to the product of its
// extents, yielding a flat tensor_size with one entry per top-level axis.
ngraph::tensor_size ngraph::tensor_size::sizes() const
{
vector<size_t> tmp;
if (m_tree.is_list())
{
for (auto s : m_tree.get_list())
{
tmp.push_back(s.reduce([](size_t a, size_t b) { return a * b; }));
}
}
else
{
tmp.push_back(m_tree.get_value());
}
return tensor_size(tmp, m_element_type);
}
// Print the underlying shape tree.
std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::tensor_size& s)
{
out << s.m_tree;
return out;
}
//================================================================================================
//
//================================================================================================
// Default: empty tree, float element type.
ngraph::tensor_stride::tensor_stride()
    : m_tree{}
    , m_element_type{element_type_float}
{
}
// Copy the shape tree of a tensor_size; the slots are overwritten with stride
// values by tensor_size::full_strides(). Members are initialized directly in
// the init list instead of default-constructed and then assigned in the body.
// Parentheses (not braces) are used so tree's copy/vector constructors are
// selected rather than its initializer_list<tree> constructor.
ngraph::tensor_stride::tensor_stride(const tensor_size& s)
    : m_tree(s.m_tree)
    , m_element_type{s.m_element_type}
{
}
// Private: flat stride list from a vector (used by reduce_strides()).
ngraph::tensor_stride::tensor_stride(const std::vector<size_t>& list, const ElementType& et)
    : m_tree(list)
    , m_element_type{et}
{
}
// Collapse each top-level group of the stride tree to its minimum stride,
// yielding a flat tensor_stride with one entry per top-level axis.
ngraph::tensor_stride ngraph::tensor_stride::reduce_strides() const
{
vector<size_t> tmp;
if (m_tree.is_list())
{
for (auto s : m_tree.get_list())
{
tmp.push_back(s.reduce([](size_t a, size_t b) { return min(a, b); }));
}
}
else
{
tmp.push_back(m_tree.get_value());
}
return tensor_stride(tmp, m_element_type);
}
// A tensor_stride already stores full strides; return a copy.
ngraph::tensor_stride ngraph::tensor_stride::full_strides() const
{
return *this;
}
// Alias for reduce_strides(); mirrors tensor_size::strides().
ngraph::tensor_stride ngraph::tensor_stride::strides() const
{
return reduce_strides();
}
// Print the underlying stride tree.
std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::tensor_stride& s)
{
out << s.m_tree;
return out;
}
#pragma once
#include <cstdio>
#include <vector>
#include <initializer_list>
#include "element_type.hpp"
#include "tree.hpp"
namespace ngraph
{
class tensor_size;
class tensor_stride;
}
//================================================================================================
//
//================================================================================================
// A (possibly nested) tensor shape plus its element type. The nested tree
// form lets grouped axes be expressed, e.g. {{C, H}, N}.
class ngraph::tensor_size
{
friend class tensor_stride;
public:
tensor_size();
// Single dimension of extent s.
tensor_size(size_t s, ElementType et = element_type_float);
// Nested shape from a brace list of sub-trees.
tensor_size(const std::initializer_list<scalar_tree>& list,
ElementType et = element_type_float);
const ElementType& get_type() const { return m_element_type; }
// Per-dimension strides in bytes, same nesting as the shape.
tensor_stride full_strides() const;
// Strides with each group collapsed to its minimum.
tensor_stride strides() const;
// Shape with each top-level group collapsed to the product of its extents.
tensor_size sizes() const;
tensor_size operator[](size_t index) const;
friend std::ostream& operator<<(std::ostream& out, const tensor_size& s);
private:
// Flat shape from a vector of extents; used internally by sizes().
tensor_size(const std::vector<size_t>&, const ElementType&);
scalar_tree m_tree;
ElementType m_element_type;
};
//================================================================================================
//
//================================================================================================
// Per-dimension byte strides for a tensor_size, stored in the same nested
// tree form as the shape it was derived from.
class ngraph::tensor_stride
{
friend class tensor_size;
public:
tensor_stride();
const ElementType& get_type() const { return m_element_type; }
// Returns a copy; a tensor_stride already holds full strides.
tensor_stride full_strides() const;
// Alias for reduce_strides().
tensor_stride strides() const;
// Each top-level group collapsed to its minimum stride.
tensor_stride reduce_strides() const;
tensor_stride operator[](size_t index) const;
friend std::ostream& operator<<(std::ostream& out, const tensor_stride& s);
private:
// Constructed from a tensor_size by tensor_size::full_strides().
tensor_stride(const tensor_size&);
tensor_stride(const std::vector<size_t>&, const ElementType&);
scalar_tree m_tree;
ElementType m_element_type;
};
# Demo: build flat and nested Axes and show indexing plus the
# nested-list / flattened-list conversions.
import ngraph as ng
# Three axes of lengths 5, 3 and 7.
C, H, N = ng.make_axis(5), ng.make_axis(3), ng.make_axis(7)
# a is flat; b groups C and H into one nested axis.
a = ng.Axes(axes=[C, H, N])
b = ng.Axes(axes=[[C, H], N])
print('a={}'.format(a))
print('b={}'.format(b))
print('a[0]={}'.format(a[0]))
print('a[1]={}'.format(a[1]))
print('a[2]={}'.format(a[2]))
print('b[0]={}'.format(b[0]))
print('b[1]={}'.format(b[1]))
print('as_nested_list(a)={}'.format(ng.Axes.as_nested_list(a)))
print('as_flattened_list(a)={}'.format(ng.Axes.as_flattened_list(a)))
print('as_nested_list(b)={}'.format(ng.Axes.as_nested_list(b)))
print('as_flattened_list(b)={}'.format(ng.Axes.as_flattened_list(b)))
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <string>
#include <memory>
#include <map>
#include <vector>
#include <type_traits>
#include <sstream>
#include "element_type.hpp"
class ExecutionState;
class Op;
// class TensorDescription;
class ComputationOp;
using computation_op_ptr = std::shared_ptr<ComputationOp>;
using op_ptr = std::shared_ptr<Op>;
using scalar_t = float;
//================================================================================================
// TensorInterface
//================================================================================================
// Abstract interface for a tensor value: exposes its element type and a
// textual rendering of its contents.
class TensorInterface
{
public:
virtual ~TensorInterface() {}
// Element type of the stored value(s).
virtual const ElementType& element_type() const = 0;
// Human-readable rendering of the stored value.
virtual std::string value_string() const = 0;
};
//================================================================================================
// Tensor
//================================================================================================
// Holds a single scalar value of type T behind the TensorInterface.
// NOTE(review): m_element_type is always element_type_float regardless of T;
// a T -> ElementType mapping is needed before non-float tensors are usable.
template <typename T>
class Tensor : public TensorInterface
{
public:
    Tensor(const T& val)
        : m_value{val}
        , m_element_type{element_type_float}
    {
    }
    virtual ~Tensor() {}
    const ElementType& element_type() const override { return m_element_type; }
    // Render the stored value as text. Streams any T supporting operator<<,
    // instead of returning the "WTF" placeholder the original implementation
    // produced for every non-floating-point T.
    std::string value_string() const override
    {
        std::stringstream ss;
        ss << m_value;
        return ss.str();
    }
private:
    T m_value;
    ElementType m_element_type;
};
//================================================================================================
// Transformer
//================================================================================================
// Abstract backend transformer: concrete backends (e.g. CpuTransformer)
// expose their ExecutionState through this interface.
class Transformer
{
public:
virtual ~Transformer() {}
// The backend's execution state; lifetime is owned by the transformer.
virtual ExecutionState& execution_state() = 0;
};
//================================================================================================
// TensorDescription
//================================================================================================
// class TensorDescription
// {
// public:
// virtual ~TensorDescription();
// virtual axes_key_t axes_key() const = 0;
// virtual std::string name() const = 0;
// virtual std::vector<size_t> shape() const = 0;
// virtual std::shared_ptr<TensorDescription> base() = 0;
// virtual ElementType element_type() const = 0;
// virtual size_t tensor_size() = 0;
// virtual bool is_persistent() = 0;
// virtual bool is_input() = 0;
// };
//================================================================================================
// Op
//================================================================================================
// class Op
// {
// // Any operation that can be in an AST.
// // Arguments:
// // args: Values used by this node.
// // const: The value of a constant Op, or None,
// // constant (bool): The Op is constant. Default False.
// // forward: If not None, the node to use instead of this node.
// // metadata: String key value dictionary for frontend metadata.
// // kwargs: Args defined in related classes.
// // Attributes:
// // const: The value of a constant.
// // constant (bool): The value is constant.
// // control_deps (OrderedSet): Ops in addtion to args that must run before this op.
// // persistent (bool): The value will be retained from computation to computation and
// // not shared. Always True if reference is set.
// // metadata: Dictionary with of string keys and values used for attaching
// // arbitrary metadata to nodes.
// // trainable: The value is trainable.
// public:
// virtual ~Op() {}
// virtual std::string name() const = 0;
// virtual tensor_description_ptr tensor_description() = 0;
// virtual op_ptr tensor() = 0;
// virtual bool is_tensor_op() = 0;
// virtual bool is_state_op() const = 0;
// virtual bool is_sequencing_op() const = 0;
// virtual op_ptr effective_tensor_op() = 0;
// virtual const std::vector<op_ptr>& all_deps() const = 0;
// // ops
// // TODO support multiple types
// static op_ptr constant(float value)
// {
// op_ptr = make_shared<LiteralScalarOp>(value);
// }
// };
//================================================================================================
// TensorOp
//================================================================================================
// class TensorOp : public Op
// {
// public:
// std::string name() const override { return "TensorOp"; }
// tensor_description_ptr tensor_description() override { return nullptr; }
// op_ptr tensor() override { return nullptr; }
// bool is_tensor_op() override { return true; }
// bool is_state_op() const override { return false; }
// op_ptr effective_tensor_op() override { return nullptr; }
// const std::vector<op_ptr>& all_deps() const override { return m_all_deps; }
// private:
// std::vector<op_ptr> m_all_deps;
// };
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "mock_transformer.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "mock.hpp"
#include "exop.hpp"
//================================================================================================
// CpuTransformer
//================================================================================================
// CPU backend transformer: owns its ExecutionState by value and hands out a
// reference via the Transformer interface.
class CpuTransformer : public Transformer
{
public:
virtual ~CpuTransformer() {}
ExecutionState& execution_state() override { return m_execution_state; }
private:
ExecutionState m_execution_state;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ndarray.hpp"
// Construct an n-d array view: element type, per-dimension extents, shared
// raw byte buffer, byte offset into that buffer, and per-dimension strides.
// NOTE(review): _shape and _buffer are taken by value and then copied into
// the members; moving them would save a copy — confirm before changing.
ngraph::ndarray::ndarray(ElementType _dtype,
std::vector<size_t> _shape,
std::shared_ptr<char> _buffer,
size_t _offset,
const tensor_stride& _strides)
: dtype{_dtype}
, shape{_shape}
, buffer{_buffer}
, strides{_strides}
, offset{_offset}
{
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
#include <memory>
#include "element_type.hpp"
#include "strides.hpp"
namespace ngraph
{
    // Forward declaration; the class is defined below at global scope.
    class ndarray;
}
// Minimal n-dimensional array descriptor: element type, per-dimension extents,
// shared untyped storage, an offset into that storage, and per-axis strides.
// All members are public data; no invariants are enforced by this class.
class ngraph::ndarray
{
public:
    // All parameters have defaults, so this also serves as the default
    // constructor (empty float array with null storage).
    ndarray(ElementType dtype = element_type_float,
            std::vector<size_t> shape = std::vector<size_t>(),
            std::shared_ptr<char> data = nullptr,
            size_t offset = 0,
            const tensor_stride& strides = tensor_stride());
    ElementType dtype;            // element type of the stored values
    std::vector<size_t> shape;    // extent of each dimension
    std::shared_ptr<char> buffer; // shared raw storage (may be null)
    tensor_stride strides;        // per-axis strides (units defined by tensor_stride)
    size_t offset;                // offset into buffer where this view begins
};
This diff is collapsed.
This diff is collapsed.
#include "tree.hpp"
#include "util.hpp"
//================================================================================================
//
//================================================================================================
#pragma once
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <vector>

#include "util.hpp"
namespace ngraph
{
    // Forward declaration; the template is defined below at global scope.
    template <typename T>
    class tree;

    // Convenience alias for a tree whose leaves hold size_t values.
    using scalar_tree = ngraph::tree<size_t>;
}
//================================================================================================
//
//================================================================================================
// A tree node is either a single value of type T (a leaf) or an ordered list
// of subtrees (an interior node).  Which of the two a node is gets fixed at
// construction and is reported by is_list().
template <typename T>
class ngraph::tree
{
public:
    // Leaf node holding a single value.
    tree(T s)
        : m_list{}
        , m_value{s}
        , m_is_list{false}
    {
    }

    // Interior node built from a brace-enclosed list of subtrees.
    tree(const std::initializer_list<tree<T>>& list)
        : m_list{}
        , m_value{0}
        , m_is_list{true}
    {
        m_list = list;
    }

    // Interior node whose children are all leaves, one per element of `list`.
    tree(const std::vector<T>& list)
        : m_list{}
        , m_value{0}
        , m_is_list{true}
    {
        m_list.reserve(list.size()); // avoid repeated reallocation
        for (auto s : list)
        {
            m_list.push_back(tree(s));
        }
    }

    bool is_list() const { return m_is_list; }
    // Value stored at a leaf; for interior nodes this is the placeholder 0.
    T get_value() const { return m_value; }
    const std::vector<tree>& get_list() const { return m_list; }

    // Depth-first, left-to-right visit of every leaf.  `func` receives a
    // pointer to the leaf's value and may modify it in place.
    static void traverse_tree(tree& s, std::function<void(T*)> func)
    {
        if (s.is_list())
        {
            for (tree& s1 : s.m_list)
            {
                traverse_tree(s1, func);
            }
        }
        else
        {
            func(&(s.m_value));
        }
    }

    friend std::ostream& operator<<(std::ostream& out, const tree& s)
    {
        if (s.is_list())
        {
            out << "(" << join(s.get_list(), ", ") << ")";
        }
        else
        {
            out << s.get_value();
        }
        return out;
    }

    // Fold all leaves left-to-right with `func`.  A leaf reduces to its own
    // value; an empty list reduces to a value-initialized T (0 for numbers).
    T reduce(const std::function<T(T, T)>& func) const
    {
        // BUGFIX: the accumulator was declared `size_t rc`, which silently
        // converted every intermediate result through size_t for non-integral
        // T (and failed to compile for non-numeric T).  It must be of type T.
        T rc{};
        if (is_list())
        {
            switch (m_list.size())
            {
            case 0: break; // rc stays value-initialized
            case 1: rc = m_list[0].reduce(func); break;
            default:
                rc = m_list[0].reduce(func);
                // size_t index: avoids the signed/unsigned comparison the
                // original `int i` loop produced.
                for (size_t i = 1; i < m_list.size(); i++)
                {
                    rc = func(rc, m_list[i].reduce(func));
                }
                break;
            }
        }
        else
        {
            rc = m_value;
        }
        return rc;
    }

private:
    std::vector<tree> m_list; // children (meaningful for interior nodes only)
    T m_value;                // payload (meaningful for leaf nodes only)
    bool m_is_list;           // true => interior node, false => leaf
};
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright 2017 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Enable ExternalProject CMake module
include(ExternalProject)

# Download and build GoogleTest at build time.  Install/update steps are
# disabled because we link the static archive straight out of the build tree.
ExternalProject_Add(
    gtest
    GIT_REPOSITORY https://github.com/google/googletest.git
    GIT_TAG release-1.8.0
    PREFIX ${CMAKE_CURRENT_BINARY_DIR}/gtest
    # Disable install step
    INSTALL_COMMAND ""
    UPDATE_COMMAND ""
)

# Get GTest source and binary directories from the external project
ExternalProject_Get_Property(gtest source_dir binary_dir)

# Wrap the built archive in an imported target so test programs can simply
# target_link_libraries() against it and inherit its usage requirements.
add_library(libgtest IMPORTED STATIC GLOBAL)
add_dependencies(libgtest gtest)

# INTERFACE_INCLUDE_DIRECTORIES requires the directory to exist at generate
# time, but the external project has not been downloaded yet at configure
# time — creating the (empty) directory up front makes it work.
file(MAKE_DIRECTORY "${source_dir}/googletest/include")

set_target_properties(libgtest PROPERTIES
    "IMPORTED_LOCATION" "${binary_dir}/googlemock/gtest/libgtest.a"
    "IMPORTED_LINK_INTERFACE_LIBRARIES" "${CMAKE_THREAD_LIBS_INIT}"
    "INTERFACE_INCLUDE_DIRECTORIES" "${source_dir}/googletest/include"
)

# Kept for any target in this directory that does not link libgtest directly.
include_directories(SYSTEM "${source_dir}/googletest/include")
# Unit-test sources, listed explicitly so additions show up in diffs.
set(SRC
    main.cpp
    util.cpp
    tensor.cpp
    exop.cpp
    axes.cpp
    element_type.cpp
    op.cpp
    uuid.cpp
    names.cpp
    strides.cpp
)

# NOTE(review): the target name `test` collides with CTest's reserved `test`
# target if enable_testing() is ever used; consider renaming (e.g. unit-test).
add_executable(test ${SRC})

# Expose the source directory to the tests as a CURDIR string macro.  Scoping
# it with target_compile_definitions avoids the original CMAKE_CXX_FLAGS hack,
# which leaked the define into every target configured after this file.
target_compile_definitions(test PRIVATE "CURDIR=\"${CMAKE_CURRENT_SOURCE_DIR}\"")

# Link the platform thread library portably instead of hardcoding `pthread`.
find_package(Threads REQUIRED)
target_link_libraries(test PRIVATE ngraph Threads::Threads libgtest)

# Ensure the ngraph library and gtest are built before the test binary links.
add_dependencies(test ngraph libgtest)

# `make runtest` builds and runs the unit-test binary; $<TARGET_FILE:test>
# resolves the real output path instead of hardcoding it.
add_custom_target(runtest
    COMMAND $<TARGET_FILE:test>
    DEPENDS test)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment