Commit 8c16125d authored by Scott Cyphers's avatar Scott Cyphers

Merge branch 'master' into cyphers/view

parents 0064cfd0 8d57ce68
...@@ -44,3 +44,10 @@ SpacesInSquareBrackets: false ...@@ -44,3 +44,10 @@ SpacesInSquareBrackets: false
SortIncludes: false SortIncludes: false
ReflowComments: true ReflowComments: true
IncludeCategories:
- Regex: '^".*'
Priority: 3
- Regex: '^<.*'
Priority: 2
SortIncludes: true
...@@ -27,7 +27,10 @@ const ngraph::ElementType element_type_uint64_t = ngraph::ElementType(64, false, ...@@ -27,7 +27,10 @@ const ngraph::ElementType element_type_uint64_t = ngraph::ElementType(64, false,
std::map<std::string, ngraph::ElementType> ngraph::ElementType::m_element_list; std::map<std::string, ngraph::ElementType> ngraph::ElementType::m_element_list;
ngraph::ElementType::ElementType(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname) ngraph::ElementType::ElementType(size_t bitwidth,
bool is_float,
bool is_signed,
const std::string& cname)
: m_bitwidth{bitwidth} : m_bitwidth{bitwidth}
, m_is_float{is_float} , m_is_float{is_float}
, m_is_signed{is_signed} , m_is_signed{is_signed}
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#pragma once #pragma once
#include <string>
#include <map> #include <map>
#include <string>
namespace ngraph namespace ngraph
{ {
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
*/ */
#include <chrono> #include <chrono>
#include <condition_variable>
#include <ctime>
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
#include <ctime>
#include <thread>
#include <mutex> #include <mutex>
#include <condition_variable> #include <thread>
#include "log.hpp" #include "log.hpp"
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
#pragma once #pragma once
#include <deque>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <deque>
namespace nervana namespace nervana
{ {
......
...@@ -14,40 +14,39 @@ ...@@ -14,40 +14,39 @@
#pragma once #pragma once
#include <string>
#include <map> #include <map>
#include <string>
namespace ngraph namespace ngraph
{ {
//================================================================================================
//================================================================================================ // NameableValue
// NameableValue // An Axis labels a dimension of a tensor. The op-graph uses
// An Axis labels a dimension of a tensor. The op-graph uses // the identity of Axis objects to pair and specify dimensions in
// the identity of Axis objects to pair and specify dimensions in // symbolic expressions. This system has several advantages over
// symbolic expressions. This system has several advantages over // using the length and position of the axis as in other frameworks:
// using the length and position of the axis as in other frameworks: //
// // 1) Convenience. The dimensions of tensors, which may be nested
// 1) Convenience. The dimensions of tensors, which may be nested // deep in a computation graph, can be specified without having to
// deep in a computation graph, can be specified without having to // calculate their lengths.
// calculate their lengths. //
// // 2) Safety. Axis labels are analogous to types in general-purpose
// 2) Safety. Axis labels are analogous to types in general-purpose // programming languages, allowing objects to interact only when
// programming languages, allowing objects to interact only when // they are permitted to do so in advance. In symbolic computation,
// they are permitted to do so in advance. In symbolic computation, // this prevents interference between axes that happen to have the
// this prevents interference between axes that happen to have the // same lengths but are logically distinct, e.g. if the number of
// same lengths but are logically distinct, e.g. if the number of // training examples and the number of input features are both 50.
// training examples and the number of input features are both 50. //
// // TODO: Please add to the list...
// TODO: Please add to the list... //
// // Arguments:
// Arguments: // length: The length of the axis.
// length: The length of the axis. // batch: Whether the axis is a batch axis.
// batch: Whether the axis is a batch axis. // recurrent: Whether the axis is a recurrent axis.
// recurrent: Whether the axis is a recurrent axis. //================================================================================================
//================================================================================================ class NameableValue
class NameableValue {
{ public:
public:
//!----------------------------------------------------------------------------------- //!-----------------------------------------------------------------------------------
//! NameableValue //! NameableValue
//! An object that can be named. //! An object that can be named.
...@@ -103,7 +102,6 @@ public: ...@@ -103,7 +102,6 @@ public:
std::string m_graph_label; std::string m_graph_label;
std::string m_short_name; std::string m_short_name;
std::string m_doc_string; std::string m_doc_string;
}; };
} // end namespace ngraph } // end namespace ngraph
#include <iostream>
#include <algorithm> #include <algorithm>
#include <iostream>
#include "strides.hpp" #include "strides.hpp"
#include "util.hpp" #include "util.hpp"
......
#pragma once #pragma once
#include <cstdio> #include <cstdio>
#include <vector>
#include <initializer_list> #include <initializer_list>
#include <vector>
#include "element_type.hpp" #include "element_type.hpp"
#include "tree.hpp" #include "tree.hpp"
...@@ -27,7 +27,6 @@ public: ...@@ -27,7 +27,6 @@ public:
ElementType et = element_type_float); ElementType et = element_type_float);
const ElementType& get_type() const { return m_element_type; } const ElementType& get_type() const { return m_element_type; }
tensor_stride full_strides() const; tensor_stride full_strides() const;
tensor_stride strides() const; tensor_stride strides() const;
tensor_size sizes() const; tensor_size sizes() const;
...@@ -53,7 +52,6 @@ class ngraph::tensor_stride ...@@ -53,7 +52,6 @@ class ngraph::tensor_stride
public: public:
tensor_stride(); tensor_stride();
const ElementType& get_type() const { return m_element_type; } const ElementType& get_type() const { return m_element_type; }
tensor_stride full_strides() const; tensor_stride full_strides() const;
tensor_stride strides() const; tensor_stride strides() const;
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <cassert>
#include <cmath>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <cmath>
#include <cassert>
#include "axes.hpp" #include "axes.hpp"
#include "util.hpp" #include "util.hpp"
......
...@@ -14,130 +14,130 @@ ...@@ -14,130 +14,130 @@
#pragma once #pragma once
#include <vector>
#include <string>
#include <memory>
#include <limits>
#include <initializer_list> #include <initializer_list>
#include <limits>
#include <memory>
#include <set> #include <set>
#include <string>
#include <vector>
#include "uuid.hpp"
#include "element_type.hpp" #include "element_type.hpp"
#include "names.hpp" #include "names.hpp"
#include "util.hpp"
#include "strides.hpp" #include "strides.hpp"
#include "util.hpp"
#include "uuid.hpp"
#include "uuid.hpp" #include "uuid.hpp"
namespace ngraph namespace ngraph
{ {
class Axes; class Axes;
class Axis; class Axis;
class FlattenedAxis; class FlattenedAxis;
class TensorDescription; class TensorDescription;
class Op; class Op;
using op_ptr = std::shared_ptr<Op>; using op_ptr = std::shared_ptr<Op>;
using tensor_description_ptr = std::shared_ptr<TensorDescription>; using tensor_description_ptr = std::shared_ptr<TensorDescription>;
using axes_key_t = size_t; using axes_key_t = size_t;
class slice class slice
{ {
public: public:
slice(int64_t start = -1, int64_t stop = -1, int64_t step = 1); slice(int64_t start = -1, int64_t stop = -1, int64_t step = 1);
size_t sliced_length(size_t length) const; size_t sliced_length(size_t length) const;
private: private:
size_t m_start; size_t m_start;
size_t m_stop; size_t m_stop;
int64_t m_step; int64_t m_step;
}; };
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// default_dtype // default_dtype
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// def default_dtype(dtype=None): // def default_dtype(dtype=None):
// if dtype is None: // if dtype is None:
// dtype = np.dtype(np.float32) // dtype = np.dtype(np.float32)
// elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype): // elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype):
// try: // try:
// dtype = np.dtype(dtype) // dtype = np.dtype(dtype)
// except TypeError: // except TypeError:
// raise TypeError("Could not cast {} to np.dtype".format(dtype)) // raise TypeError("Could not cast {} to np.dtype".format(dtype))
// return dtype // return dtype
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// default_int_dtype // default_int_dtype
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// def default_int_dtype(dtype=None): // def default_int_dtype(dtype=None):
// if dtype is None: // if dtype is None:
// dtype = np.dtype(np.int32) // dtype = np.dtype(np.int32)
// elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype): // elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype):
// try: // try:
// dtype = np.dtype(dtype) // dtype = np.dtype(dtype)
// except TypeError: // except TypeError:
// raise TypeError("Could not cast {} to np.dtype".format(dtype)) // raise TypeError("Could not cast {} to np.dtype".format(dtype))
// return dtype // return dtype
//================================================================================================ //================================================================================================
// make_axis // make_axis
// Returns a new Axis. // Returns a new Axis.
// //
// Args: // Args:
// length (int, optional): Length of the axis. // length (int, optional): Length of the axis.
// name (String, optional): Name of the axis. // name (String, optional): Name of the axis.
// batch (bool, optional): This is a batch axis. Defaults to False. // batch (bool, optional): This is a batch axis. Defaults to False.
// recurrent (bool, optional): This is a recurrent axis. Defaults to False. // recurrent (bool, optional): This is a recurrent axis. Defaults to False.
// docstring (String, optional): A docstring for the axis. // docstring (String, optional): A docstring for the axis.
// //
// Returns: // Returns:
// Axis: A new Axis. // Axis: A new Axis.
//================================================================================================ //================================================================================================
Axis make_axis(size_t length, Axis make_axis(size_t length,
const std::string& name = "", const std::string& name = "",
bool batch = false, bool batch = false,
bool recurrent = false); bool recurrent = false);
//================================================================================================ //================================================================================================
// make_axes // make_axes
// Makes an Axes object. // Makes an Axes object.
// //
// Args: // Args:
// axes: A list of Axis. // axes: A list of Axis.
// //
// Returns: // Returns:
// Axes: An Axes. // Axes: An Axes.
//================================================================================================ //================================================================================================
Axes make_axes(const std::vector<Axis>&); Axes make_axes(const std::vector<Axis>&);
//================================================================================================ //================================================================================================
// Axis // Axis
// An Axis labels a dimension of a tensor. The op-graph uses // An Axis labels a dimension of a tensor. The op-graph uses
// the identity of Axis objects to pair and specify dimensions in // the identity of Axis objects to pair and specify dimensions in
// symbolic expressions. This system has several advantages over // symbolic expressions. This system has several advantages over
// using the length and position of the axis as in other frameworks: // using the length and position of the axis as in other frameworks:
// //
// 1) Convenience. The dimensions of tensors, which may be nested // 1) Convenience. The dimensions of tensors, which may be nested
// deep in a computation graph, can be specified without having to // deep in a computation graph, can be specified without having to
// calculate their lengths. // calculate their lengths.
// //
// 2) Safety. Axis labels are analogous to types in general-purpose // 2) Safety. Axis labels are analogous to types in general-purpose
// programming languages, allowing objects to interact only when // programming languages, allowing objects to interact only when
// they are permitted to do so in advance. In symbolic computation, // they are permitted to do so in advance. In symbolic computation,
// this prevents interference between axes that happen to have the // this prevents interference between axes that happen to have the
// same lengths but are logically distinct, e.g. if the number of // same lengths but are logically distinct, e.g. if the number of
// training examples and the number of input features are both 50. // training examples and the number of input features are both 50.
// //
// TODO: Please add to the list... // TODO: Please add to the list...
// //
// Arguments: // Arguments:
// length: The length of the axis. // length: The length of the axis.
// batch: Whether the axis is a batch axis. // batch: Whether the axis is a batch axis.
// recurrent: Whether the axis is a recurrent axis. // recurrent: Whether the axis is a recurrent axis.
//================================================================================================ //================================================================================================
class Axis class Axis
{ {
public: public:
Axis& operator+(const Axis& rhs); Axis& operator+(const Axis& rhs);
Axis& operator-(const Axis& rhs); Axis& operator-(const Axis& rhs);
...@@ -145,7 +145,6 @@ public: ...@@ -145,7 +145,6 @@ public:
Axis(size_t length, const std::string& new_name); Axis(size_t length, const std::string& new_name);
virtual ~Axis() {} virtual ~Axis() {}
void named(const std::string& new_name); void named(const std::string& new_name);
//!----------------------------------------------------------------------------------- //!-----------------------------------------------------------------------------------
...@@ -232,99 +231,99 @@ public: ...@@ -232,99 +231,99 @@ public:
uuid_type uuid; uuid_type uuid;
size_t __length; size_t __length;
static size_t __name_counter; static size_t __name_counter;
}; };
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// _sliced_length // _sliced_length
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// def _sliced_length(s, incoming_length): // def _sliced_length(s, incoming_length):
// start, stop, step = s.indices(incoming_length) // start, stop, step = s.indices(incoming_length)
// # max with 0 so we dont ever return a negative length. This // # max with 0 so we dont ever return a negative length. This
// # matches how python handles it internally. Raising an exception // # matches how python handles it internally. Raising an exception
// # might also be reasonable. // # might also be reasonable.
// if step == 1: // if step == 1:
// return max(stop - start, 0) // return max(stop - start, 0)
// elif step == -1: // elif step == -1:
// return max(start - stop, 0) // return max(start - stop, 0)
// else: // else:
// _validate_slice(s) // _validate_slice(s)
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// _validate_slice // _validate_slice
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// def _validate_slice(s): // def _validate_slice(s):
// if s.step not in (-1, 1, None): // if s.step not in (-1, 1, None):
// raise ValueError(( // raise ValueError((
// 'SlicedAxis cant currently handle a step size other ' // 'SlicedAxis cant currently handle a step size other '
// 'than -1, 1 or None. Was given {step} in slice {slice}' // 'than -1, 1 or None. Was given {step} in slice {slice}'
// ).format( // ).format(
// step=s.step, // step=s.step,
// slice=s, // slice=s,
// )) // ))
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// slice_axis // slice_axis
// Slice an axis, return complete new axis // Slice an axis, return complete new axis
// TODO: deprecate this after the axis refactoring // TODO: deprecate this after the axis refactoring
// //
// Arguments: // Arguments:
// axis: the axis to be sliced // axis: the axis to be sliced
// s: slice // s: slice
// //
// Returns: // Returns:
// Axis instance, the new sliced axis // Axis instance, the new sliced axis
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// def slice_axis(axis, s): // def slice_axis(axis, s):
Axis slice_axis(const Axis& axis, const slice& s); Axis slice_axis(const Axis& axis, const slice& s);
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// duplicates // duplicates
// Returns a list of Axis objects which have duplicate names in arr // Returns a list of Axis objects which have duplicate names in arr
// //
// Arguments: // Arguments:
// arr: The iterable of Axis objects to check for duplicates in. // arr: The iterable of Axis objects to check for duplicates in.
// //
// Returns: // Returns:
// list of Axis: duplicate Axis found in arr // list of Axis: duplicate Axis found in arr
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
std::vector<std::string> duplicates(const std::vector<Axis>& ax); std::vector<std::string> duplicates(const std::vector<Axis>& ax);
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// with_args_as_axes // with_args_as_axes
// A decorator to cast arguments to axes. // A decorator to cast arguments to axes.
// //
// Arguments: // Arguments:
// f: The function to be decorated. // f: The function to be decorated.
// //
// Returns: // Returns:
// The decorated function. // The decorated function.
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// def with_args_as_axes(f): // def with_args_as_axes(f):
// @wraps(f) // @wraps(f)
// def wrapper(*args): // def wrapper(*args):
// """ // """
// The decorated function. Performs the conversion // The decorated function. Performs the conversion
// to Axes. // to Axes.
// Arguments: // Arguments:
// *args: Arguments intended for the original function. // *args: Arguments intended for the original function.
// Returns: // Returns:
// Return value of the original function. // Return value of the original function.
// """ // """
// args = [Axes(arg) for arg in args] // args = [Axes(arg) for arg in args]
// return f(*args) // return f(*args)
// return wrapper // return wrapper
//================================================================================================ //================================================================================================
// Axes // Axes
// An Axes is a tuple of Axis objects used as a label for a tensor's // An Axes is a tuple of Axis objects used as a label for a tensor's
// dimensions. // dimensions.
//================================================================================================ //================================================================================================
class Axes class Axes
{ {
public: public:
std::vector<Axis> axes; std::vector<Axis> axes;
uuid_type uuid; uuid_type uuid;
...@@ -706,47 +705,47 @@ public: ...@@ -706,47 +705,47 @@ public:
std::vector<Axis> convert(const Axes& ax); std::vector<Axis> convert(const Axes& ax);
std::vector<Axis> convert(const std::vector<Axes>& ax); std::vector<Axis> convert(const std::vector<Axes>& ax);
private: private:
void check_duplicates(); void check_duplicates();
}; };
//================================================================================================
// DuplicateAxisNames
//================================================================================================
// class DuplicateAxisNames(ValueError):
// def __init__(self, message, duplicate_axis_names):
// super(DuplicateAxisNames, self).__init__(message)
// self.duplicate_axis_names = duplicate_axis_names
//================================================================================================
// IncompatibleAxesError
//================================================================================================
// class IncompatibleAxesError(ValueError):
// pass
//================================================================================================
// UnmatchedAxesError
//================================================================================================
// class UnmatchedAxesError(IncompatibleAxesError):
// pass
//================================================================================================ //================================================================================================
// AxesMap // DuplicateAxisNames
// AxesMap provides a way to define a axis name mapping: {Axis.name: Axis.name} and //================================================================================================
// then apply this mapping to an Axes and get new Axes out.
// // class DuplicateAxisNames(ValueError):
// Right now AxesMap is implemented as immutible because I didn't want to deal with // def __init__(self, message, duplicate_axis_names):
// enforcing _assert_valid_axes_map on every method which mutates a dict and I didn't // super(DuplicateAxisNames, self).__init__(message)
// need a mutable datastructure anyway. Feel free to make it mutable and add in
// invariant enforcement. // self.duplicate_axis_names = duplicate_axis_names
//================================================================================================
class AxesMap : public std::map<std::string, std::string> //================================================================================================
{ // IncompatibleAxesError
public: //================================================================================================
// class IncompatibleAxesError(ValueError):
// pass
//================================================================================================
// UnmatchedAxesError
//================================================================================================
// class UnmatchedAxesError(IncompatibleAxesError):
// pass
//================================================================================================
// AxesMap
// AxesMap provides a way to define a axis name mapping: {Axis.name: Axis.name} and
// then apply this mapping to an Axes and get new Axes out.
//
// Right now AxesMap is implemented as immutible because I didn't want to deal with
// enforcing _assert_valid_axes_map on every method which mutates a dict and I didn't
// need a mutable datastructure anyway. Feel free to make it mutable and add in
// invariant enforcement.
//================================================================================================
class AxesMap : public std::map<std::string, std::string>
{
public:
AxesMap(const std::pair<std::string, std::string>&); AxesMap(const std::pair<std::string, std::string>&);
AxesMap(std::initializer_list<std::pair<std::string, std::string>>); AxesMap(std::initializer_list<std::pair<std::string, std::string>>);
...@@ -762,74 +761,70 @@ public: ...@@ -762,74 +761,70 @@ public:
//-------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------
Axis map_axis(const Axis& old_axis) const; Axis map_axis(const Axis& old_axis) const;
private: private:
std::map<std::string, std::set<std::string>> duplicate_axis_names(); std::map<std::string, std::set<std::string>> duplicate_axis_names();
void assert_valid_axes_map(); void assert_valid_axes_map();
public: public:
// def invert(self): // def invert(self):
// return {v: k for k, v in self.items()} // return {v: k for k, v in self.items()}
}; };
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// _reduce_nested // _reduce_nested
// Reduces a nested sequence by applying a function to each // Reduces a nested sequence by applying a function to each
// of its elements and returns an aggregation. // of its elements and returns an aggregation.
// //
// Arguments: // Arguments:
// elem: The object to be reduced, either a sequence // elem: The object to be reduced, either a sequence
// or a singleton. // or a singleton.
// agg: A variable holding information collected // agg: A variable holding information collected
// as the sequence is collapsed. // as the sequence is collapsed.
// func: A function to augment the aggregate by processing // func: A function to augment the aggregate by processing
// a singleton. Should have the form func(agg, elem) -> agg // a singleton. Should have the form func(agg, elem) -> agg
// //
// Returns: // Returns:
// agg: The final aggregate returned by the function. // agg: The final aggregate returned by the function.
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// def _reduce_nested(elem, agg, func): // def _reduce_nested(elem, agg, func):
// if isinstance(elem, collections.Iterable): // if isinstance(elem, collections.Iterable):
// for sub in elem: // for sub in elem:
// agg = _reduce_nested(sub, agg, func) // agg = _reduce_nested(sub, agg, func)
// return agg // return agg
// else: // else:
// return func(agg, elem) // return func(agg, elem)
//================================================================================================ //================================================================================================
// FlattenedAxis // FlattenedAxis
// A FlattenedAxis has length which is the product of the lengths of all // A FlattenedAxis has length which is the product of the lengths of all
// Axis in the axes. The original Axes object is stored so that we can later // Axis in the axes. The original Axes object is stored so that we can later
// unflatten this Axis back to its original component Axis. // unflatten this Axis back to its original component Axis.
// //
// Notes: since we allows Axis to have duplicated names globally, NameableValue // Notes: since we allows Axis to have duplicated names globally, NameableValue
// is not used here. // is not used here.
//================================================================================================ //================================================================================================
class FlattenedAxis : public Axis class FlattenedAxis : public Axis
{ {
public: public:
FlattenedAxis(const std::vector<Axis>& list, const std::string& new_name = ""); FlattenedAxis(const std::vector<Axis>& list, const std::string& new_name = "");
virtual ~FlattenedAxis() {} virtual ~FlattenedAxis() {}
//-------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------
// Returns: // Returns:
// True is this is a FlattendAxis. // True is this is a FlattendAxis.
//-------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------
bool is_flattened() const { return true; } bool is_flattened() const { return true; }
//-------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------
// Returns: // Returns:
// Whether this axes contains no collapsed axes. // Whether this axes contains no collapsed axes.
//-------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------
bool empty() const { return axes.size() == 0; } bool empty() const { return axes.size() == 0; }
//-------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------
// Returns: // Returns:
// Whether this axes contains exactly one collapsed axes. // Whether this axes contains exactly one collapsed axes.
//-------------------------------------------------------------------------------------------- //--------------------------------------------------------------------------------------------
bool single() const { return axes.size() == 0; } bool single() const { return axes.size() == 0; }
bool operator==(const Axis& other) const; bool operator==(const Axis& other) const;
// def __hash__(self): // def __hash__(self):
...@@ -841,96 +836,96 @@ public: ...@@ -841,96 +836,96 @@ public:
// return 'FlattenedAxis(%s)' % ', '.join(repr(axis) for axis in self.axes) // return 'FlattenedAxis(%s)' % ', '.join(repr(axis) for axis in self.axes)
std::vector<Axis> axes; std::vector<Axis> axes;
}; };
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// default_dtype // default_dtype
// Reduces a nested tuple describing the strides of a tensor // Reduces a nested tuple describing the strides of a tensor
// into a tuple giving the stride of each of its dimensions. // into a tuple giving the stride of each of its dimensions.
// //
// Arguments: // Arguments:
// strides: The nested tuple. // strides: The nested tuple.
// //
// Returns: // Returns:
// strides: The tuple of strides. // strides: The tuple of strides.
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// def reduce_strides(strides): // def reduce_strides(strides):
// return tuple(int(_reduce_nested(elem, float('inf'), min)) // return tuple(int(_reduce_nested(elem, float('inf'), min))
// for elem in strides) // for elem in strides)
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// _make_stride // _make_stride
// Generates a nested tuple that provides the striding information // Generates a nested tuple that provides the striding information
// for an occurrence of axis. If the axis is a FlattenedAxis, the // for an occurrence of axis. If the axis is a FlattenedAxis, the
// stride will be a tuple containing the strides of each collapsed // stride will be a tuple containing the strides of each collapsed
// axis. Otherwise, the stride will be an integer. // axis. Otherwise, the stride will be an integer.
// //
// Arguments: // Arguments:
// inner_size: The total size of all dimensions smaller than this // inner_size: The total size of all dimensions smaller than this
// axis, i.e. all axes to the right of this one when they are // axis, i.e. all axes to the right of this one when they are
// laid out in c-contiguous order. // laid out in c-contiguous order.
// axis: The axis for which we are generating a stride. // axis: The axis for which we are generating a stride.
// fsz: A nested tuple supplying the sizes of each dimension collapsed // fsz: A nested tuple supplying the sizes of each dimension collapsed
// into the axis. The size may be larger than the length of the axis. // into the axis. The size may be larger than the length of the axis.
// //
// Returns: // Returns:
// inner_size: The total size of this axis and all smaller dimensions. // inner_size: The total size of this axis and all smaller dimensions.
// stride: The stride given to the axis. // stride: The stride given to the axis.
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// def _make_stride(inner_size, axis, fsz): // def _make_stride(inner_size, axis, fsz):
// if axis.is_flattened: // if axis.is_flattened:
// return _make_strides(inner_size, axis.axes, fsz) // return _make_strides(inner_size, axis.axes, fsz)
// else: // else:
// stride = inner_size // stride = inner_size
// inner_size *= fsz // inner_size *= fsz
// return inner_size, stride // return inner_size, stride
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// _make_strides // _make_strides
// Generates a tuple of strides for a set of axes. See _make_stride // Generates a tuple of strides for a set of axes. See _make_stride
// for a description of the stride given to each axis. // for a description of the stride given to each axis.
// //
// Arguments: // Arguments:
// inner_size: The total size of all dimensions smaller than // inner_size: The total size of all dimensions smaller than
// the axes. // the axes.
// axes: The axes for which we are generating strides. // axes: The axes for which we are generating strides.
// full_sizes: The size of each axis. // full_sizes: The size of each axis.
// //
// Returns: // Returns:
// inner_size: The total size of these axes and all smaller dimensions. // inner_size: The total size of these axes and all smaller dimensions.
// strides: The strides generated for the axes. // strides: The strides generated for the axes.
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// def _make_strides(inner_size, axes, full_sizes): // def _make_strides(inner_size, axes, full_sizes):
// full_strides = [] // full_strides = []
// for axis, fsz in reversed(list(zip(axes, full_sizes))): // for axis, fsz in reversed(list(zip(axes, full_sizes))):
// inner_size, stride = _make_stride(inner_size, axis, fsz) // inner_size, stride = _make_stride(inner_size, axis, fsz)
// full_strides.append(stride) // full_strides.append(stride)
// return inner_size, tuple(reversed(full_strides)) // return inner_size, tuple(reversed(full_strides))
//================================================================================================ //================================================================================================
// TensorDescription // TensorDescription
// Description of a tensor that will be allocated in hardware. // Description of a tensor that will be allocated in hardware.
// //
// Names the tensor's dimensions with axes and holds pointers to the // Names the tensor's dimensions with axes and holds pointers to the
// buffer allocated by the analysis and the backend tensor value // buffer allocated by the analysis and the backend tensor value
// (e.g. a cpu or gpu tensor). // (e.g. a cpu or gpu tensor).
// //
// Arguments: // Arguments:
// axes: Axes of the tensor. // axes: Axes of the tensor.
// base: If a view, the viewed tensor's description. // base: If a view, the viewed tensor's description.
// dtype: The type of the tensor. // dtype: The type of the tensor.
// full_strides: The strides of each axis. // full_strides: The strides of each axis.
// full_sizes: The allocated size of each axis (may be larger than the axis). // full_sizes: The allocated size of each axis (may be larger than the axis).
// offset: An offset into the viewed tensor. // offset: An offset into the viewed tensor.
// next_tensor_decription: In a reshape, tensor description of reshaped tensor. // next_tensor_decription: In a reshape, tensor description of reshaped tensor.
// is_persistent: The tensor should be persistent, i.e. survive from computation to // is_persistent: The tensor should be persistent, i.e. survive from computation to
// computation. // computation.
// is_input: The device tensor can be written from the host. // is_input: The device tensor can be written from the host.
// **kwargs: Additional args for related classes. // **kwargs: Additional args for related classes.
//================================================================================================ //================================================================================================
class TensorDescription : public NameableValue class TensorDescription : public NameableValue
{ {
public: public:
//!----------------------------------------------------------------------------------- //!-----------------------------------------------------------------------------------
//! constructor //! constructor
//!----------------------------------------------------------------------------------- //!-----------------------------------------------------------------------------------
...@@ -1487,7 +1482,7 @@ public: ...@@ -1487,7 +1482,7 @@ public:
ngraph::tensor_size full_sizes; ngraph::tensor_size full_sizes;
ngraph::tensor_stride full_strides; ngraph::tensor_stride full_strides;
tensor_description_ptr next_tensor_description; tensor_description_ptr next_tensor_description;
}; };
} // end of namespace ngraph } // end of namespace ngraph
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <cmath>
#include <exception>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <exception>
#include <cmath>
#include "exop.hpp" #include "exop.hpp"
#include "op_graph.hpp" #include "op_graph.hpp"
......
...@@ -15,67 +15,66 @@ ...@@ -15,67 +15,66 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <string> #include <list>
#include <map> #include <map>
#include <memory> #include <memory>
#include <vector>
#include <sstream>
#include <set> #include <set>
#include <list> #include <sstream>
#include <string>
#include <vector>
#include "axes.hpp"
#include "mock.hpp" #include "mock.hpp"
#include "op_graph.hpp" #include "op_graph.hpp"
#include "axes.hpp"
namespace ngraph namespace ngraph
{ {
// forward declaration. This will hopefully go away
// forward declaration. This will hopefully go away class ExecutionGraph;
class ExecutionGraph; class TensorDescription;
class TensorDescription; class InputDecl;
class InputDecl; class OutputDecl;
class OutputDecl; class TensorDecl;
class TensorDecl; class TensorViewDecl;
class TensorViewDecl; class ExOp;
class ExOp; class Op;
class Op; class ComputationDecl;
class ComputationDecl; class ExOpBlock;
class ExOpBlock; class ExecutionState;
class ExecutionState;
using output_decl_ptr = std::shared_ptr<OutputDecl>;
using output_decl_ptr = std::shared_ptr<OutputDecl>; using input_decl_ptr = std::shared_ptr<InputDecl>;
using input_decl_ptr = std::shared_ptr<InputDecl>; using tensor_decl_ptr = std::shared_ptr<TensorDecl>;
using tensor_decl_ptr = std::shared_ptr<TensorDecl>; using tensor_view_decl_ptr = std::shared_ptr<TensorViewDecl>;
using tensor_view_decl_ptr = std::shared_ptr<TensorViewDecl>; using exop_ptr = std::shared_ptr<ExOp>;
using exop_ptr = std::shared_ptr<ExOp>; using computation_decl_ptr = std::shared_ptr<ComputationDecl>;
using computation_decl_ptr = std::shared_ptr<ComputationDecl>; using execution_graph_ptr = std::shared_ptr<ExecutionGraph>;
using execution_graph_ptr = std::shared_ptr<ExecutionGraph>; using exop_block_ptr = std::shared_ptr<ExOpBlock>;
using exop_block_ptr = std::shared_ptr<ExOpBlock>; using tensor_ptr = std::shared_ptr<TensorInterface>;
using tensor_ptr = std::shared_ptr<TensorInterface>; using transformer_ptr = std::shared_ptr<Transformer>;
using transformer_ptr = std::shared_ptr<Transformer>; using execution_state_ptr = std::shared_ptr<ExecutionState>;
using execution_state_ptr = std::shared_ptr<ExecutionState>;
//================================================================================================
//================================================================================================ // OutputDecl
// OutputDecl // One value computed by an exop
// One value computed by an exop //
// // Arguments:
// Arguments: // exop: The exop.
// exop: The exop. // pos: The position of the value, defaults to 0.
// pos: The position of the value, defaults to 0. // tensor_description: Tensor description of the value.
// tensor_description: Tensor description of the value. // write_view: The tensor view where the value is written.
// write_view: The tensor view where the value is written. //
// // Attributes:
// Attributes: // exop: The exop.
// exop: The exop. // pos: The position of the value.
// pos: The position of the value. // tensor_description: Tensor description of the value.
// tensor_description: Tensor description of the value. // write_view: The tensor view where the value is written.
// write_view: The tensor view where the value is written. // value_users: Arguments using this value.
// value_users: Arguments using this value. //================================================================================================
//================================================================================================
class OutputDecl
class OutputDecl {
{ public:
public:
OutputDecl(const ExOp& _exop, size_t _pos, tensor_decl_ptr, tensor_description_ptr); OutputDecl(const ExOp& _exop, size_t _pos, tensor_decl_ptr, tensor_description_ptr);
tensor_decl_ptr tensor_decl(); tensor_decl_ptr tensor_decl();
void tensor_decl(tensor_decl_ptr tensor_decl); void tensor_decl(tensor_decl_ptr tensor_decl);
...@@ -95,29 +94,29 @@ public: ...@@ -95,29 +94,29 @@ public:
tensor_decl_ptr __tensor; tensor_decl_ptr __tensor;
tensor_view_decl_ptr __write_view; tensor_view_decl_ptr __write_view;
std::set<InputDecl*> value_users; std::set<InputDecl*> value_users;
}; };
//================================================================================================ //================================================================================================
// InputDecl // InputDecl
// An argument for an exop. // An argument for an exop.
// //
// Arguments: // Arguments:
// exop: The exop. // exop: The exop.
// pos: The position of the value, defaults to 0. // pos: The position of the value, defaults to 0.
// tensor_description: Tensor description of the value. // tensor_description: Tensor description of the value.
// read_view: The tensor view where the value is read from. // read_view: The tensor view where the value is read from.
// //
// Attributes: // Attributes:
// exop: The exop. // exop: The exop.
// pos: The position of the value. // pos: The position of the value.
// tensor_description: Tensor description of the value. // tensor_description: Tensor description of the value.
// read_view: The tensor view where the value is read from. // read_view: The tensor view where the value is read from.
// value: Arguments supplying this value. // value: Arguments supplying this value.
//================================================================================================ //================================================================================================
class InputDecl class InputDecl
{ {
public: public:
InputDecl(const ExOp& _exop, InputDecl(const ExOp& _exop,
size_t _pos, size_t _pos,
tensor_description_ptr _tensor_description, tensor_description_ptr _tensor_description,
...@@ -134,37 +133,37 @@ public: ...@@ -134,37 +133,37 @@ public:
tensor_description_ptr tensor_description; tensor_description_ptr tensor_description;
tensor_view_decl_ptr read_view; tensor_view_decl_ptr read_view;
OutputDecl* m_value; OutputDecl* m_value;
}; };
//================================================================================================ //================================================================================================
// ExecutionGraphElt // ExecutionGraphElt
// An element of an exection graph. // An element of an exection graph.
// //
// Arguments: // Arguments:
// execution_graph: The execution graph that indexes this exop. // execution_graph: The execution graph that indexes this exop.
// //
// Attributes: // Attributes:
// execution_graph: The execution graph that indexes this exop. // execution_graph: The execution graph that indexes this exop.
//================================================================================================ //================================================================================================
class ExecutionGraphElt class ExecutionGraphElt
{ {
public: public:
ExecutionGraphElt(ExecutionGraph& eg) ExecutionGraphElt(ExecutionGraph& eg)
: execution_graph{eg} : execution_graph{eg}
{ {
} }
ExecutionGraph& execution_graph; ExecutionGraph& execution_graph;
}; };
//================================================================================================ //================================================================================================
// ExOp // ExOp
//================================================================================================ //================================================================================================
class ExOp : public ExecutionGraphElt class ExOp : public ExecutionGraphElt
{ {
public: public:
// An exop that indicates an op to be executed. // An exop that indicates an op to be executed.
// The op might be different from what was originally found in the computation graph. // The op might be different from what was originally found in the computation graph.
...@@ -220,17 +219,18 @@ public: ...@@ -220,17 +219,18 @@ public:
std::vector<tensor_decl_ptr> liveness_free_list; std::vector<tensor_decl_ptr> liveness_free_list;
std::vector<tensor_decl_ptr> liveness_new_list; std::vector<tensor_decl_ptr> liveness_new_list;
std::vector<InputDecl> args; std::vector<InputDecl> args;
std::vector<InputDecl*> write_args; // TODO: Kludge until we have values with writers/readers std::vector<InputDecl*>
write_args; // TODO: Kludge until we have values with writers/readers
std::vector<OutputDecl> values; std::vector<OutputDecl> values;
}; };
//================================================================================================ //================================================================================================
// TensorDecl // TensorDecl
//================================================================================================ //================================================================================================
class TensorDecl : public ExecutionGraphElt class TensorDecl : public ExecutionGraphElt
{ {
public: public:
// Allocate for a tensor. // Allocate for a tensor.
// Arguments: // Arguments:
...@@ -294,15 +294,15 @@ public: ...@@ -294,15 +294,15 @@ public:
bool is_compile_only; bool is_compile_only;
tensor_ptr initial_value; tensor_ptr initial_value;
tensor_decl_ptr source_tensor; tensor_decl_ptr source_tensor;
}; };
//================================================================================================ //================================================================================================
// ExOpBlock // ExOpBlock
//================================================================================================ //================================================================================================
class ExOpBlock : public ExecutionGraphElt class ExOpBlock : public ExecutionGraphElt
{ {
public: public:
// Sequentially execute a list of exops. // Sequentially execute a list of exops.
// Attributes: // Attributes:
...@@ -312,7 +312,8 @@ public: ...@@ -312,7 +312,8 @@ public:
// root_set: Set of exops whose values are needed. // root_set: Set of exops whose values are needed.
ExOpBlock(ComputationDecl& cgraph); ExOpBlock(ComputationDecl& cgraph);
bool is_exop_end_of_list(); bool is_exop_end_of_list();
void add_ops(std::initializer_list<computation_op_ptr> roots, exop_ptr after_exop = nullptr); void add_ops(std::initializer_list<computation_op_ptr> roots,
exop_ptr after_exop = nullptr);
exop_ptr add_op(op_ptr op, exop_ptr after_exop); exop_ptr add_op(op_ptr op, exop_ptr after_exop);
exop_ptr add_exop(exop_ptr exop, exop_ptr after_exop = nullptr); exop_ptr add_exop(exop_ptr exop, exop_ptr after_exop = nullptr);
void move_exop_to_after_exop(exop_ptr exop, exop_ptr after_exop); void move_exop_to_after_exop(exop_ptr exop, exop_ptr after_exop);
...@@ -336,17 +337,16 @@ public: ...@@ -336,17 +337,16 @@ public:
// replacement for next_exop, prev_exop // replacement for next_exop, prev_exop
std::list<exop_ptr>::iterator begin() { return op_list.begin(); } std::list<exop_ptr>::iterator begin() { return op_list.begin(); }
std::list<exop_ptr>::iterator end() { return op_list.end(); } std::list<exop_ptr>::iterator end() { return op_list.end(); }
std::list<exop_ptr> op_list; std::list<exop_ptr> op_list;
}; };
//================================================================================================ //================================================================================================
// TensorViewDecl // TensorViewDecl
//================================================================================================ //================================================================================================
class TensorViewDecl : public ExecutionGraphElt class TensorViewDecl : public ExecutionGraphElt
{ {
public: public:
// Declare a view of a tensor. // Declare a view of a tensor.
// Arguments: // Arguments:
...@@ -373,17 +373,17 @@ public: ...@@ -373,17 +373,17 @@ public:
std::set<InputDecl*> readers; std::set<InputDecl*> readers;
std::set<OutputDecl*> writers; std::set<OutputDecl*> writers;
OutputDecl* value; OutputDecl* value;
}; };
// static exop_ptr _default_default; // static exop_ptr _default_default;
//================================================================================================ //================================================================================================
// ComputationDecl // ComputationDecl
//================================================================================================ //================================================================================================
class ComputationDecl : public ExecutionGraphElt class ComputationDecl : public ExecutionGraphElt
{ {
public: public:
// One computation to be run. // One computation to be run.
// Every computation has its own execution graph. Persistent tensors are shared // Every computation has its own execution graph. Persistent tensors are shared
...@@ -406,15 +406,15 @@ public: ...@@ -406,15 +406,15 @@ public:
exop_block_ptr exop_block; exop_block_ptr exop_block;
exop_ptr returns; exop_ptr returns;
std::set<ExOp*> values; std::set<ExOp*> values;
}; };
//================================================================================================ //================================================================================================
// ExecutionState // ExecutionState
//================================================================================================ //================================================================================================
class ExecutionState class ExecutionState
{ {
public: public:
// Proxy for the state of a device. // Proxy for the state of a device.
// Arguments: // Arguments:
...@@ -429,15 +429,15 @@ public: ...@@ -429,15 +429,15 @@ public:
// persistent tensors // persistent tensors
std::map<tensor_description_ptr, tensor_decl_ptr> __tensors_decls; std::map<tensor_description_ptr, tensor_decl_ptr> __tensors_decls;
}; };
//================================================================================================ //================================================================================================
// ExecutionGraph // ExecutionGraph
//================================================================================================ //================================================================================================
class ExecutionGraph class ExecutionGraph
{ {
public: public:
// Information for compiling a computation_op. // Information for compiling a computation_op.
// Arguments: // Arguments:
...@@ -452,6 +452,6 @@ public: ...@@ -452,6 +452,6 @@ public:
// temporary tensors // temporary tensors
std::map<tensor_description_ptr, tensor_decl_ptr> tensor_decls; std::map<tensor_description_ptr, tensor_decl_ptr> tensor_decls;
computation_decl_ptr computation_decl; computation_decl_ptr computation_decl;
}; };
} // end namespace ngraph } // end namespace ngraph
...@@ -14,49 +14,47 @@ ...@@ -14,49 +14,47 @@
#pragma once #pragma once
#include <string>
#include <memory>
#include <map> #include <map>
#include <vector> #include <memory>
#include <type_traits>
#include <sstream> #include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include "element_type.hpp" #include "element_type.hpp"
namespace ngraph namespace ngraph
{ {
class ExecutionState;
class ExecutionState; class Op;
// class TensorDescription;
class Op; class ComputationOp;
// class TensorDescription;
class ComputationOp;
using computation_op_ptr = std::shared_ptr<ComputationOp>; using computation_op_ptr = std::shared_ptr<ComputationOp>;
using op_ptr = std::shared_ptr<Op>; using op_ptr = std::shared_ptr<Op>;
using scalar_t = float; using scalar_t = float;
//================================================================================================ //================================================================================================
// TensorInterface // TensorInterface
//================================================================================================ //================================================================================================
class TensorInterface class TensorInterface
{ {
public: public:
virtual ~TensorInterface() {} virtual ~TensorInterface() {}
virtual const ElementType& element_type() const = 0; virtual const ElementType& element_type() const = 0;
virtual std::string value_string() const = 0; virtual std::string value_string() const = 0;
}; };
//================================================================================================ //================================================================================================
// Tensor // Tensor
//================================================================================================ //================================================================================================
template <typename T> template <typename T>
class Tensor : public TensorInterface class Tensor : public TensorInterface
{ {
public: public:
Tensor(const T& val) Tensor(const T& val)
: m_value{val} : m_value{val}
, m_element_type{element_type_float} , m_element_type{element_type_float}
...@@ -64,9 +62,7 @@ public: ...@@ -64,9 +62,7 @@ public:
} }
virtual ~Tensor() {} virtual ~Tensor() {}
const ElementType& element_type() const override { return m_element_type; } const ElementType& element_type() const override { return m_element_type; }
std::string value_string() const override std::string value_string() const override
{ {
std::string rc = "WTF"; std::string rc = "WTF";
...@@ -79,104 +75,104 @@ public: ...@@ -79,104 +75,104 @@ public:
return rc; return rc;
} }
private: private:
T m_value; T m_value;
ElementType m_element_type; ElementType m_element_type;
}; };
//================================================================================================ //================================================================================================
// Transformer // Transformer
//================================================================================================ //================================================================================================
class Transformer class Transformer
{ {
public: public:
virtual ~Transformer() {} virtual ~Transformer() {}
virtual ExecutionState& execution_state() = 0; virtual ExecutionState& execution_state() = 0;
}; };
//================================================================================================ //================================================================================================
// TensorDescription // TensorDescription
//================================================================================================ //================================================================================================
// class TensorDescription // class TensorDescription
// { // {
// public: // public:
// virtual ~TensorDescription(); // virtual ~TensorDescription();
// virtual axes_key_t axes_key() const = 0; // virtual axes_key_t axes_key() const = 0;
// virtual std::string name() const = 0; // virtual std::string name() const = 0;
// virtual std::vector<size_t> shape() const = 0; // virtual std::vector<size_t> shape() const = 0;
// virtual std::shared_ptr<TensorDescription> base() = 0; // virtual std::shared_ptr<TensorDescription> base() = 0;
// virtual ElementType element_type() const = 0; // virtual ElementType element_type() const = 0;
// virtual size_t tensor_size() = 0; // virtual size_t tensor_size() = 0;
// virtual bool is_persistent() = 0; // virtual bool is_persistent() = 0;
// virtual bool is_input() = 0; // virtual bool is_input() = 0;
// }; // };
//================================================================================================ //================================================================================================
// Op // Op
//================================================================================================ //================================================================================================
// class Op // class Op
// { // {
// // Any operation that can be in an AST. // // Any operation that can be in an AST.
// // Arguments: // // Arguments:
// // args: Values used by this node. // // args: Values used by this node.
// // const: The value of a constant Op, or None, // // const: The value of a constant Op, or None,
// // constant (bool): The Op is constant. Default False. // // constant (bool): The Op is constant. Default False.
// // forward: If not None, the node to use instead of this node. // // forward: If not None, the node to use instead of this node.
// // metadata: String key value dictionary for frontend metadata. // // metadata: String key value dictionary for frontend metadata.
// // kwargs: Args defined in related classes. // // kwargs: Args defined in related classes.
// // Attributes: // // Attributes:
// // const: The value of a constant. // // const: The value of a constant.
// // constant (bool): The value is constant. // // constant (bool): The value is constant.
// // control_deps (OrderedSet): Ops in addtion to args that must run before this op. // // control_deps (OrderedSet): Ops in addtion to args that must run before this op.
// // persistent (bool): The value will be retained from computation to computation and // // persistent (bool): The value will be retained from computation to computation and
// // not shared. Always True if reference is set. // // not shared. Always True if reference is set.
// // metadata: Dictionary with of string keys and values used for attaching // // metadata: Dictionary with of string keys and values used for attaching
// // arbitrary metadata to nodes. // // arbitrary metadata to nodes.
// // trainable: The value is trainable. // // trainable: The value is trainable.
// public: // public:
// virtual ~Op() {} // virtual ~Op() {}
// virtual std::string name() const = 0; // virtual std::string name() const = 0;
// virtual tensor_description_ptr tensor_description() = 0; // virtual tensor_description_ptr tensor_description() = 0;
// virtual op_ptr tensor() = 0; // virtual op_ptr tensor() = 0;
// virtual bool is_tensor_op() = 0; // virtual bool is_tensor_op() = 0;
// virtual bool is_state_op() const = 0; // virtual bool is_state_op() const = 0;
// virtual bool is_sequencing_op() const = 0; // virtual bool is_sequencing_op() const = 0;
// virtual op_ptr effective_tensor_op() = 0; // virtual op_ptr effective_tensor_op() = 0;
// virtual const std::vector<op_ptr>& all_deps() const = 0; // virtual const std::vector<op_ptr>& all_deps() const = 0;
// // ops // // ops
// // TODO support multiple types // // TODO support multiple types
// static op_ptr constant(float value) // static op_ptr constant(float value)
// { // {
// op_ptr = make_shared<LiteralScalarOp>(value); // op_ptr = make_shared<LiteralScalarOp>(value);
// } // }
// }; // };
//================================================================================================ //================================================================================================
// TensorOp // TensorOp
//================================================================================================ //================================================================================================
// class TensorOp : public Op // class TensorOp : public Op
// { // {
// public: // public:
// std::string name() const override { return "TensorOp"; } // std::string name() const override { return "TensorOp"; }
// tensor_description_ptr tensor_description() override { return nullptr; } // tensor_description_ptr tensor_description() override { return nullptr; }
// op_ptr tensor() override { return nullptr; } // op_ptr tensor() override { return nullptr; }
// bool is_tensor_op() override { return true; } // bool is_tensor_op() override { return true; }
// bool is_state_op() const override { return false; } // bool is_state_op() const override { return false; }
// op_ptr effective_tensor_op() override { return nullptr; } // op_ptr effective_tensor_op() override { return nullptr; }
// const std::vector<op_ptr>& all_deps() const override { return m_all_deps; } // const std::vector<op_ptr>& all_deps() const override { return m_all_deps; }
// private: // private:
// std::vector<op_ptr> m_all_deps; // std::vector<op_ptr> m_all_deps;
// }; // };
} // end of namespace ngraph } // end of namespace ngraph
...@@ -14,24 +14,21 @@ ...@@ -14,24 +14,21 @@
#pragma once #pragma once
#include "mock.hpp"
#include "exop.hpp" #include "exop.hpp"
#include "mock.hpp"
namespace ngraph namespace ngraph
{ {
//================================================================================================
//================================================================================================ // CpuTransformer
// CpuTransformer //================================================================================================
//================================================================================================ class CpuTransformer : public Transformer
class CpuTransformer : public Transformer {
{ public:
public:
virtual ~CpuTransformer() {} virtual ~CpuTransformer() {}
ExecutionState& execution_state() override { return m_execution_state; } ExecutionState& execution_state() override { return m_execution_state; }
private:
private:
ExecutionState m_execution_state; ExecutionState m_execution_state;
}; };
} // end namespace ngraph } // end namespace ngraph
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#pragma once #pragma once
#include <vector>
#include <memory> #include <memory>
#include <vector>
#include "element_type.hpp" #include "element_type.hpp"
#include "strides.hpp" #include "strides.hpp"
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include <sstream> #include <sstream>
#include "op_graph.hpp"
#include "axes.hpp" #include "axes.hpp"
#include "op_graph.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -2794,7 +2794,9 @@ ElementWiseOp::ElementWiseOp() ...@@ -2794,7 +2794,9 @@ ElementWiseOp::ElementWiseOp()
{ {
} }
void ElementWiseOp::ElementWiseOp_init(std::vector<op_ptr>, Axes) {} void ElementWiseOp::ElementWiseOp_init(std::vector<op_ptr>, Axes)
{
}
//================================================================================================ //================================================================================================
// UnaryElementWiseOp // UnaryElementWiseOp
......
This source diff could not be displayed because it is too large. You can view the blob instead.
#pragma once #pragma once
#include <algorithm>
#include <functional> #include <functional>
#include <vector>
#include <initializer_list> #include <initializer_list>
#include <iostream> #include <iostream>
#include <algorithm> #include <vector>
#include "util.hpp" #include "util.hpp"
...@@ -51,7 +51,6 @@ public: ...@@ -51,7 +51,6 @@ public:
bool is_list() const { return m_is_list; } bool is_list() const { return m_is_list; }
T get_value() const { return m_value; } T get_value() const { return m_value; }
const std::vector<tree>& get_list() const { return m_list; } const std::vector<tree>& get_list() const { return m_list; }
static void traverse_tree(tree& s, std::function<void(T*)> func) static void traverse_tree(tree& s, std::function<void(T*)> func)
{ {
if (s.is_list()) if (s.is_list())
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <map>
#include <iomanip> #include <iomanip>
#include <map>
#include "util.hpp" #include "util.hpp"
......
...@@ -14,23 +14,22 @@ ...@@ -14,23 +14,22 @@
#pragma once #pragma once
#include <string>
#include <sstream>
#include <vector>
#include <chrono>
#include <algorithm> #include <algorithm>
#include <map> #include <chrono>
#include <iostream> #include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
namespace ngraph namespace ngraph
{ {
class stopwatch;
extern std::map<std::string, stopwatch*> stopwatch_statistics;
class stopwatch; template <typename T>
extern std::map<std::string, stopwatch*> stopwatch_statistics; std::string join(const T& v, const std::string& sep)
{
template <typename T>
std::string join(const T& v, const std::string& sep)
{
std::ostringstream ss; std::ostringstream ss;
for (const auto& x : v) for (const auto& x : v)
{ {
...@@ -41,11 +40,11 @@ std::string join(const T& v, const std::string& sep) ...@@ -41,11 +40,11 @@ std::string join(const T& v, const std::string& sep)
ss << x; ss << x;
} }
return ss.str(); return ss.str();
} }
template <typename U, typename T> template <typename U, typename T>
bool contains(const U& container, const T& obj) bool contains(const U& container, const T& obj)
{ {
bool rc = false; bool rc = false;
for (auto o : container) for (auto o : container)
{ {
...@@ -56,11 +55,11 @@ bool contains(const U& container, const T& obj) ...@@ -56,11 +55,11 @@ bool contains(const U& container, const T& obj)
} }
} }
return rc; return rc;
} }
template <typename U, typename T> template <typename U, typename T>
bool contains_key(const U& container, const T& obj) bool contains_key(const U& container, const T& obj)
{ {
bool rc = false; bool rc = false;
for (auto o : container) for (auto o : container)
{ {
...@@ -71,28 +70,28 @@ bool contains_key(const U& container, const T& obj) ...@@ -71,28 +70,28 @@ bool contains_key(const U& container, const T& obj)
} }
} }
return rc; return rc;
} }
template <typename U, typename T> template <typename U, typename T>
void remove_from(U& container, const T& obj) void remove_from(U& container, const T& obj)
{ {
auto it = container.find(obj); auto it = container.find(obj);
if (it != container.end()) if (it != container.end())
{ {
container.erase(it); container.erase(it);
} }
} }
size_t hash_combine(const std::vector<size_t>& list); size_t hash_combine(const std::vector<size_t>& list);
void dump(std::ostream& out, const void*, size_t); void dump(std::ostream& out, const void*, size_t);
std::string to_lower(const std::string& s); std::string to_lower(const std::string& s);
std::string trim(const std::string& s); std::string trim(const std::string& s);
std::vector<std::string> split(const std::string& s, char delimiter, bool trim = false); std::vector<std::string> split(const std::string& s, char delimiter, bool trim = false);
class stopwatch class stopwatch
{ {
public: public:
stopwatch() {} stopwatch() {}
stopwatch(const std::string& name) stopwatch(const std::string& name)
: m_name{name} : m_name{name}
...@@ -149,21 +148,21 @@ public: ...@@ -149,21 +148,21 @@ public:
size_t get_total_milliseconds() const { return get_total_nanoseconds() / 1e6; } size_t get_total_milliseconds() const { return get_total_nanoseconds() / 1e6; }
size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; } size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; }
size_t get_total_nanoseconds() const { return m_total_time.count(); } size_t get_total_nanoseconds() const { return m_total_time.count(); }
private:
private:
std::chrono::high_resolution_clock m_clock; std::chrono::high_resolution_clock m_clock;
std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time; std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time;
bool m_active = false; bool m_active = false;
std::chrono::nanoseconds m_total_time = std::chrono::high_resolution_clock::duration::zero(); std::chrono::nanoseconds m_total_time =
std::chrono::high_resolution_clock::duration::zero();
std::chrono::nanoseconds m_last_time; std::chrono::nanoseconds m_last_time;
size_t m_total_count = 0; size_t m_total_count = 0;
std::string m_name; std::string m_name;
}; };
template <class InputIt, class BinaryOp> template <class InputIt, class BinaryOp>
typename std::iterator_traits<InputIt>::value_type typename std::iterator_traits<InputIt>::value_type
reduce(InputIt first, InputIt last, BinaryOp op) reduce(InputIt first, InputIt last, BinaryOp op)
{ {
typename std::iterator_traits<InputIt>::value_type result; typename std::iterator_traits<InputIt>::value_type result;
if (first == last) if (first == last)
...@@ -180,18 +179,18 @@ typename std::iterator_traits<InputIt>::value_type ...@@ -180,18 +179,18 @@ typename std::iterator_traits<InputIt>::value_type
} }
} }
return result; return result;
} }
template <typename T> template <typename T>
T plus(const T& a, const T& b) T plus(const T& a, const T& b)
{ {
return a + b; return a + b;
} }
template <typename T> template <typename T>
T mul(const T& a, const T& b) T mul(const T& a, const T& b)
{ {
return a * b; return a * b;
} }
} // end namespace ngraph } // end namespace ngraph
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
#pragma once #pragma once
#include <array> #include <array>
#include <random> #include <cstring>
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
#include <cstring> #include <random>
static std::mt19937_64 random_generator; static std::mt19937_64 random_generator;
...@@ -74,7 +74,6 @@ public: ...@@ -74,7 +74,6 @@ public:
} }
bool operator!=(const uuid_type& other) const { return !(*this == other); } bool operator!=(const uuid_type& other) const { return !(*this == other); }
friend std::ostream& operator<<(std::ostream& out, const uuid_type& id) friend std::ostream& operator<<(std::ostream& out, const uuid_type& id)
{ {
out << id.to_string(); out << id.to_string();
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream> #include <sstream>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -310,7 +310,7 @@ TEST(axes, index) ...@@ -310,7 +310,7 @@ TEST(axes, index)
EXPECT_EQ(7, b[1].length()); EXPECT_EQ(7, b[1].length());
} }
TEST(axes, as_nested_list) TEST(axes, DISABLED_as_nested_list)
{ {
Axis C = make_axis(5); Axis C = make_axis(5);
Axis H = make_axis(3); Axis H = make_axis(3);
...@@ -325,7 +325,7 @@ TEST(axes, as_nested_list) ...@@ -325,7 +325,7 @@ TEST(axes, as_nested_list)
FAIL(); FAIL();
} }
TEST(axes, flatten) TEST(axes, DISABLED_flatten)
{ {
Axis C = make_axis(5); Axis C = make_axis(5);
Axis H = make_axis(3); Axis H = make_axis(3);
...@@ -336,7 +336,7 @@ TEST(axes, flatten) ...@@ -336,7 +336,7 @@ TEST(axes, flatten)
EXPECT_TRUE(c.is_flattened()); EXPECT_TRUE(c.is_flattened());
} }
TEST(axes, as_flattened_list) TEST(axes, DISABLED_as_flattened_list)
{ {
FAIL(); FAIL();
} }
...@@ -364,7 +364,7 @@ TEST(axes, hash_axes) ...@@ -364,7 +364,7 @@ TEST(axes, hash_axes)
m2[axes] = 1; m2[axes] = 1;
} }
TEST(axes, reaxe_0d_to_1d) TEST(axes, DISABLED_reaxe_0d_to_1d)
{ {
TensorDescription td{}; TensorDescription td{};
ngraph::ndarray x = random(td); ngraph::ndarray x = random(td);
...@@ -382,7 +382,7 @@ TEST(axes, reaxe_0d_to_1d) ...@@ -382,7 +382,7 @@ TEST(axes, reaxe_0d_to_1d)
FAIL(); FAIL();
} }
TEST(axes, reaxe_0d_to_2d) TEST(axes, DISABLED_reaxe_0d_to_2d)
{ {
// td = TensorDescription(axes=()) // td = TensorDescription(axes=())
// x = random(td) // x = random(td)
...@@ -407,7 +407,7 @@ TEST(axes, reaxe_0d_to_2d) ...@@ -407,7 +407,7 @@ TEST(axes, reaxe_0d_to_2d)
// I started refactoring into smaller pieces as seen in tests above, but // I started refactoring into smaller pieces as seen in tests above, but
// stopped ... // stopped ...
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
TEST(axes, simple_tensors) TEST(axes, DISABLED_simple_tensors)
{ {
// # A simple vector // # A simple vector
// td1 = TensorDescription(axes=[ax_A]) // td1 = TensorDescription(axes=[ax_A])
...@@ -582,7 +582,7 @@ TEST(axes, axes_map) ...@@ -582,7 +582,7 @@ TEST(axes, axes_map)
// assert axes_after == axes_map.map_axes(axes_before) // assert axes_after == axes_map.map_axes(axes_before)
} }
TEST(axes, axes_map_immutable) TEST(axes, DISABLED_axes_map_immutable)
{ {
FAIL(); FAIL();
// axes_map = AxesMap({}) // axes_map = AxesMap({})
...@@ -591,7 +591,7 @@ TEST(axes, axes_map_immutable) ...@@ -591,7 +591,7 @@ TEST(axes, axes_map_immutable)
// axes_map["x"] = "y" // axes_map["x"] = "y"
} }
TEST(axes, axes_map_init_from_axes) TEST(axes, DISABLED_axes_map_init_from_axes)
{ {
FAIL(); FAIL();
// axes_map = AxesMap({ng.make_axis(1, name="aaa"): ng.make_axis(1, name="zzz")}) // axes_map = AxesMap({ng.make_axis(1, name="aaa"): ng.make_axis(1, name="zzz")})
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <iostream>
#include <chrono> #include <chrono>
#include <iostream>
#include "gtest/gtest.h" #include "gtest/gtest.h"
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -22,4 +22,6 @@ ...@@ -22,4 +22,6 @@
using namespace ngraph; using namespace ngraph;
TEST(names, name) {} TEST(names, name)
{
}
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream> #include <sstream>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <memory> #include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -134,7 +134,9 @@ TEST(util, contains) ...@@ -134,7 +134,9 @@ TEST(util, contains)
EXPECT_FALSE(contains(v1, 8)); EXPECT_FALSE(contains(v1, 8));
} }
TEST(util, remove_from) {} TEST(util, remove_from)
{
}
TEST(util, reduce) TEST(util, reduce)
{ {
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment