Commit 6b1ea2a1 authored by Robert Kimball, committed by Scott Cyphers

remove unused files (#143)

parent 8156e3e0
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include "ngraph.hpp"
#include "log.hpp"
NGraph* create_ngraph_object()
{
return new NGraph();
}
void destroy_ngraph_object(NGraph* pObj)
{
delete pObj;
}
void NGraph::add_params(const std::vector<std::string>& paramList)
{
NGRAPH_INFO << "Adding parameters";
m_params.insert(m_params.end(), paramList.begin(), paramList.end());
}
const std::vector<std::string>& NGraph::get_params() const
{
return m_params;
}
#pragma once
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include <string>
#include <vector>
class NGraph
{
public:
void add_params(const std::vector<std::string>& paramList);
const std::vector<std::string>& get_params() const;
std::string get_name() const { return "NGraph Implementation Object"; }
private:
std::vector<std::string> m_params;
};
// Factory methods
extern "C" NGraph* create_ngraph_object();
extern "C" void destroy_ngraph_object(NGraph* pObj);
// Function pointers to the factory methods
typedef NGraph* (*CreateNGraphObjPfn)();
typedef void (*DestroyNGraphObjPfn)(NGraph*);
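// The extern "C" factory functions and the function-pointer typedefs above suggest this
// header is meant to be used across a dynamic-library boundary. A minimal usage sketch
// (our own example; the library name "libngraph.so" and the parameter strings are
// placeholders, not part of the sources):
//
//     #include <dlfcn.h>
//     void* handle = dlopen("libngraph.so", RTLD_NOW);
//     auto create = reinterpret_cast<CreateNGraphObjPfn>(dlsym(handle, "create_ngraph_object"));
//     auto destroy = reinterpret_cast<DestroyNGraphObjPfn>(dlsym(handle, "destroy_ngraph_object"));
//     NGraph* graph = create();
//     graph->add_params({"param1", "param2"});
//     destroy(graph);
//     dlclose(handle);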
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include <sstream>
#include "names.hpp"
using namespace ngraph;
size_t NameableValue::__counter = 0;
std::map<std::string, NameableValue> NameableValue::__all_names;
NameableValue::NameableValue(const std::string& name,
const std::string& graph_label_type,
const std::string& doc_string)
: m_name{name}
, m_doc_string{doc_string}
{
auto glt = m_name;
if (graph_label_type.size() > 0)
{
glt = graph_label_type;
}
{
std::stringstream ss;
ss << glt << "[" << m_name << "]";
m_graph_label = ss.str();
}
}
const std::string& NameableValue::graph_label()
{
return m_graph_label;
}
const std::string& NameableValue::name()
{
return m_name;
}
void NameableValue::name(const std::string& name)
{
// Port of the original Python logic: if the requested name is already registered,
// append "_<counter>" until a unique name is found, then register it.
std::string unique_name = name;
while (__all_names.find(unique_name) != __all_names.end())
{
std::stringstream ss;
ss << name << "_" << __counter++;
unique_name = ss.str();
}
__all_names.insert({unique_name, *this});
m_name = unique_name;
}
const std::string& NameableValue::short_name()
{
// Port of the original Python logic: take the first '_'-separated token of the
// name; if that token contains a '.', keep the piece after the first '.'.
m_short_name = m_name.substr(0, m_name.find('_'));
auto dot = m_short_name.find('.');
if (dot != std::string::npos)
{
m_short_name = m_short_name.substr(dot + 1);
}
return m_short_name;
}
NameableValue& NameableValue::named(const std::string& name)
{
m_name = name;
return *this;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include <map>
#include <string>
namespace ngraph
{
//================================================================================================
// NameableValue
// An object that can be named.
//
// Arguments:
// graph_label_type: A label that should be used when drawing the graph. Defaults to
// the class name.
// name (str): The name of the object.
// doc_string: Documentation for the object.
//
// Attributes:
// graph_label_type: A label that should be used when drawing the graph.
// id: Unique id for this object.
//================================================================================================
class NameableValue
{
public:
//!-----------------------------------------------------------------------------------
//! NameableValue
//! An object that can be named.
//!
//! Arguments:
//! graph_label_type: A label that should be used when drawing the graph. Defaults to
//! the class name.
//! name (str): The name of the object.
//! doc_string: Documentation for the object.
//!
//! Attributes:
//! graph_label_type: A label that should be used when drawing the graph.
//! id: Unique id for this object.
//!-----------------------------------------------------------------------------------
NameableValue(const std::string& name,
const std::string& graph_label_type = "",
const std::string& doc_string = "");
//!-----------------------------------------------------------------------------------
//! graph_label
//! The label used for drawings of the graph.
//!-----------------------------------------------------------------------------------
const std::string& graph_label();
//!-----------------------------------------------------------------------------------
//! name
//! Returns the name of the object.
//!-----------------------------------------------------------------------------------
const std::string& name();
//!-----------------------------------------------------------------------------------
//! name
//! Sets the object name to a unique name based on name.
//!
//! Arguments:
//! name: Prefix for the name
//!-----------------------------------------------------------------------------------
void name(const std::string& name);
//!-----------------------------------------------------------------------------------
//! short_name
//!-----------------------------------------------------------------------------------
const std::string& short_name();
//!-----------------------------------------------------------------------------------
//! named
//!-----------------------------------------------------------------------------------
NameableValue& named(const std::string& name);
static size_t __counter;
static std::map<std::string, NameableValue> __all_names;
std::string m_name;
std::string m_graph_label;
std::string m_short_name;
std::string m_doc_string;
};
} // end namespace ngraph
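// A minimal usage sketch of NameableValue (our own example, not part of the original sources):
//
//     ngraph::NameableValue v{"conv1", "Convolution"};
//     v.graph_label();   // -> "Convolution[conv1]"
//     v.name();          // -> "conv1"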
#include <algorithm>
#include <iostream>
#include "strides.hpp"
#include "util.hpp"
using namespace std;
//================================================================================================
// tensor_size
//================================================================================================
ngraph::tensor_size::tensor_size()
: m_tree{}
, m_element_type{element_type_float}
{
}
ngraph::tensor_size::tensor_size(size_t s, ElementType et)
: m_tree{s}
, m_element_type{et}
{
}
ngraph::tensor_size::tensor_size(const std::initializer_list<scalar_tree>& list, ElementType et)
: m_tree{list}
, m_element_type{et}
{
}
ngraph::tensor_size::tensor_size(const std::vector<size_t>& list, const ElementType& et)
: m_tree{list}
, m_element_type{et}
{
}
ngraph::tensor_stride ngraph::tensor_size::full_strides() const
{
tensor_stride result{*this};
vector<size_t*> value_pointer_list;
vector<size_t> size_list;
scalar_tree::traverse_tree(result.m_tree, [&](size_t* value) {
value_pointer_list.push_back(value);
size_list.push_back(*value);
});
int index = value_pointer_list.size() - 1;
*value_pointer_list[index] = result.m_element_type.size();
for (index--; index >= 0; index--)
{
*value_pointer_list[index] = *value_pointer_list[index + 1] * size_list[index + 1];
}
return result;
}
ngraph::tensor_stride ngraph::tensor_size::strides() const
{
return full_strides().strides();
}
ngraph::tensor_size ngraph::tensor_size::sizes() const
{
vector<size_t> tmp;
if (m_tree.is_list())
{
for (auto s : m_tree.get_list())
{
tmp.push_back(s.reduce([](size_t a, size_t b) { return a * b; }));
}
}
else
{
tmp.push_back(m_tree.get_value());
}
return tensor_size(tmp, m_element_type);
}
std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::tensor_size& s)
{
out << s.m_tree;
return out;
}
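// A worked example of full_strides() (our own sketch, assuming scalar_tree converts from an
// integer as the constructors above suggest, and 4-byte floats): the rightmost stride is the
// element size, and each stride to its left is the product of the stride and size to its right.
//
//     ngraph::tensor_size ts{2, 3, 4};
//     ts.full_strides();   // byte strides {48, 16, 4}
//     ts.strides();        // also {48, 16, 4} here, since no axis is nested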
//================================================================================================
// tensor_stride
//================================================================================================
ngraph::tensor_stride::tensor_stride()
: m_tree{}
, m_element_type{element_type_float}
{
}
ngraph::tensor_stride::tensor_stride(const tensor_size& s)
: m_tree{}
, m_element_type{s.m_element_type}
{
m_tree = s.m_tree;
}
ngraph::tensor_stride::tensor_stride(const std::vector<size_t>& list, const ElementType& et)
: m_tree{}
, m_element_type{et}
{
m_tree = list;
}
ngraph::tensor_stride ngraph::tensor_stride::reduce_strides() const
{
vector<size_t> tmp;
if (m_tree.is_list())
{
for (auto s : m_tree.get_list())
{
tmp.push_back(s.reduce([](size_t a, size_t b) { return min(a, b); }));
}
}
else
{
tmp.push_back(m_tree.get_value());
}
return tensor_stride(tmp, m_element_type);
}
ngraph::tensor_stride ngraph::tensor_stride::full_strides() const
{
return *this;
}
ngraph::tensor_stride ngraph::tensor_stride::strides() const
{
return reduce_strides();
}
std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::tensor_stride& s)
{
out << s.m_tree;
return out;
}
#pragma once
#include <cstdio>
#include <initializer_list>
#include <vector>
#include "element_type.hpp"
#include "tree.hpp"
namespace ngraph
{
class tensor_size;
class tensor_stride;
}
//================================================================================================
// tensor_size
//================================================================================================
class ngraph::tensor_size
{
friend class tensor_stride;
public:
tensor_size();
tensor_size(size_t s, ElementType et = element_type_float);
tensor_size(const std::initializer_list<scalar_tree>& list,
ElementType et = element_type_float);
const ElementType& get_type() const { return m_element_type; }
tensor_stride full_strides() const;
tensor_stride strides() const;
tensor_size sizes() const;
tensor_size operator[](size_t index) const;
friend std::ostream& operator<<(std::ostream& out, const tensor_size& s);
private:
tensor_size(const std::vector<size_t>&, const ElementType&);
scalar_tree m_tree;
ElementType m_element_type;
};
//================================================================================================
// tensor_stride
//================================================================================================
class ngraph::tensor_stride
{
friend class tensor_size;
public:
tensor_stride();
const ElementType& get_type() const { return m_element_type; }
tensor_stride full_strides() const;
tensor_stride strides() const;
tensor_stride reduce_strides() const;
tensor_stride operator[](size_t index) const;
friend std::ostream& operator<<(std::ostream& out, const tensor_stride& s);
private:
tensor_stride(const tensor_size&);
tensor_stride(const std::vector<size_t>&, const ElementType&);
scalar_tree m_tree;
ElementType m_element_type;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include <cassert>
#include <cmath>
#include <iostream>
#include <sstream>
#include "axes.hpp"
#include "util.hpp"
using namespace ngraph;
slice::slice(int64_t start, int64_t stop, int64_t step)
: m_start{(size_t)start}
, m_stop{(size_t)stop}
, m_step{step}
{
if (step != 1 && step != -1)
{
throw std::invalid_argument("slice step must be 1 or -1");
}
if (start == -1)
{
m_start = 0;
}
if (stop == -1)
{
m_stop = std::numeric_limits<size_t>::max();
}
if (m_step > 0)
{
m_start = std::min<size_t>(m_start, m_stop);
}
else
{
m_start = std::max<size_t>(m_start, m_stop);
}
}
size_t slice::sliced_length(size_t length) const
{
size_t start = m_start;
size_t stop = std::min<size_t>(m_stop, length);
size_t rc;
if (m_step == 1)
{
// guard against unsigned wrap-around; mirrors the Python max(stop - start, 0)
rc = (stop > start) ? stop - start : 0;
}
else if (m_step == -1)
{
rc = (m_start > m_stop) ? m_start - m_stop : 0;
}
else
{
throw std::runtime_error("slice step must be 1 or -1");
}
return rc;
}
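// A few worked values for sliced_length (our own examples):
//
//     slice(2, 7).sliced_length(10);   // 5: elements 2..6
//     slice(2, 7).sliced_length(5);    // 3: the stop index is clamped to the axis length
//     slice().sliced_length(10);       // 10: the default slice selects the whole axis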
// def default_dtype(dtype=None):
// if dtype is None:
// dtype = np.dtype(np.float32)
// elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype):
// try:
// dtype = np.dtype(dtype)
// except TypeError:
// raise TypeError("Could not cast {} to np.dtype".format(dtype))
// return dtype
// def default_int_dtype(dtype=None):
// if dtype is None:
// dtype = np.dtype(np.int32)
// elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype):
// try:
// dtype = np.dtype(dtype)
// except TypeError:
// raise TypeError("Could not cast {} to np.dtype".format(dtype))
// return dtype
Axis ngraph::make_axis(size_t length, const std::string& name, bool batch, bool recurrent)
{
return Axis(length, name);
}
Axes ngraph::make_axes(const std::vector<Axis>& axis_list)
{
return Axes(axis_list);
}
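// A small usage sketch of the two helpers above (our own example):
//
//     Axis H = make_axis(32, "H");
//     Axis W = make_axis(32, "W");
//     Axes hw = make_axes({H, W});
//     hw.lengths();   // -> {32, 32}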
//================================================================================================
// Axis
// An Axis labels a dimension of a tensor. The op-graph uses
// the identity of Axis objects to pair and specify dimensions in
// symbolic expressions. This system has several advantages over
// using the length and position of the axis as in other frameworks:
//
// 1) Convenience. The dimensions of tensors, which may be nested
// deep in a computation graph, can be specified without having to
// calculate their lengths.
//
// 2) Safety. Axis labels are analogous to types in general-purpose
// programming languages, allowing objects to interact only when
// they are permitted to do so in advance. In symbolic computation,
// this prevents interference between axes that happen to have the
// same lengths but are logically distinct, e.g. if the number of
// training examples and the number of input features are both 50.
//
// TODO: Please add to the list...
//
// Arguments:
// length: The length of the axis.
// batch: Whether the axis is a batch axis.
// recurrent: Whether the axis is a recurrent axis.
//================================================================================================
size_t Axis::__name_counter = 0;
Axis::Axis()
: Axis(0, "")
{
}
Axis::Axis(size_t length, const std::string& new_name)
: name{new_name}
, uuid{uuid_type()}
, __length{length}
{
if (name.size() == 0)
{
std::stringstream ss;
ss << "Axis_" << __name_counter++;
name = ss.str();
}
}
bool Axis::is_flattened() const
{
return false;
}
bool Axis::is_batch() const
{
return name == "N";
}
bool Axis::is_recurrent() const
{
return name == "REC";
}
bool Axis::is_channel() const
{
return name == "C";
}
size_t Axis::length() const
{
return __length;
}
void Axis::length(size_t l)
{
__length = l;
}
std::ostream& ngraph::operator<<(std::ostream& out, const Axis& axis)
{
out << axis.to_string();
return out;
}
std::string Axis::to_string() const
{
std::stringstream ss;
ss << "Axis(" << name << ": " << length() << ")";
return ss.str();
}
bool Axis::operator==(const Axis& other) const
{
return name == other.name;
}
bool Axis::operator!=(const Axis& other) const
{
return !(*this == other);
}
bool Axis::operator<(const Axis& other) const
{
bool rc;
if (this->name == other.name)
{
rc = this->length() < other.length();
}
else
{
rc = this->name < other.name;
}
return rc;
}
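// Note on the comparison operators above: equality is decided by name alone, so
// (our own example) make_axis(2, "N") == make_axis(10, "N") is true; lengths only
// participate in operator< as a tie-breaker when the names match.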
// def _sliced_length(s, incoming_length):
// start, stop, step = s.indices(incoming_length)
// # max with 0 so we dont ever return a negative length. This
// # matches how python handles it internally. Raising an exception
// # might also be reasonable.
// if step == 1:
// return max(stop - start, 0)
// elif step == -1:
// return max(start - stop, 0)
// else:
// _validate_slice(s)
// // def _validate_slice(s):
// void validate_slice(const slice& s)
// {
// // if s.step not in (-1, 1, None):
// // raise ValueError((
// // 'SlicedAxis cant currently handle a step size other '
// // 'than -1, 1 or None. Was given {step} in slice {slice}'
// // ).format(
// // step=s.step,
// // slice=s,
// // ))
// }
Axis ngraph::slice_axis(const Axis& axis, const slice& s)
{
// _validate_slice(s)
// # get sliced length
// new_length = None if axis.length is None else _sliced_length(s, axis.length)
auto new_length = s.sliced_length(axis.length());
// # create sliced axis
// new_axis = make_axis(length=new_length,
// name=axis.name)
return make_axis(new_length, axis.name);
// return new_axis
}
// def duplicates(arr):
// """
// Returns a list of Axis objects which have duplicate names in arr
// Arguments:
// arr: The iterable of Axis objects to check for duplicates in.
// Returns:
// list of Axis: duplicate Axis found in arr
// """
std::vector<std::string> ngraph::duplicates(const std::vector<Axis>& ax)
{
std::map<std::string, size_t> counts;
std::vector<std::string> rc;
for (const Axis& axis : ax)
{
auto it = counts.find(axis.name);
if (it == counts.end())
{
counts.insert({axis.name, 1});
}
else
{
it->second++;
}
}
for (auto p : counts)
{
if (p.second > 1)
{
rc.push_back(p.first);
}
}
return rc;
}
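// For example (our own sketch):
//     duplicates({make_axis(2, "N"), make_axis(3, "N"), make_axis(4, "C")});
// returns {"N"}, since the name "N" appears twice.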
// def with_args_as_axes(f):
// """
// A decorator to cast arguments to axes.
// Arguments:
// f: The function to be decorated.
// Returns:
// The decorated function.
// """
// @wraps(f)
// def wrapper(*args):
// """
// The decorated function. Performs the conversion
// to Axes.
// Arguments:
// *args: Arguments intended for the original function.
// Returns:
// Return value of the original function.
// """
// args = [Axes(arg) for arg in args]
// return f(*args)
// return wrapper
//================================================================================================
// Axes
// An Axes is a tuple of Axis objects used as a label for a tensor's
// dimensions.
//================================================================================================
Axes::Axes()
: uuid{}
{
}
Axes::Axes(const Axis& axis)
{
axes.push_back(axis);
}
Axes::Axes(const std::vector<Axis>& axis_list)
{
axes = axis_list;
check_duplicates();
}
Axes::Axes(const std::initializer_list<Axes>& list)
{
axes = convert(std::vector<Axes>(list));
check_duplicates();
}
size_t Axes::size() const
{
return axes.size();
}
void Axes::check_duplicates()
{
auto dups = duplicates(axes);
if (dups.size() > 0)
{
std::stringstream ss;
ss << "The axes labels of a tensor cannot contain duplicates. Found: " << join(dups, ", ");
throw std::invalid_argument(ss.str());
}
}
const Axis& Axes::operator[](size_t index) const
{
return axes[index];
}
// Axis Axes::operator[](const slice&) const
// {
// }
// class Axes(object):
// """
// """
// def __init__(self, axes=None):
// if axes is None:
// axes = []
// elif isinstance(axes, Axis):
// axes = [axes]
// elif isinstance(axes, types.GeneratorType):
// axes = tuple(axes)
// elif isinstance(axes, (list, tuple)) and not isinstance(axes, Axes):
// axes = tuple(axes)
// def convert(seq):
// axes = convert(axes)
// for x in axes:
// if not isinstance(x, Axis):
// raise ValueError((
// 'tried to initialize an Axes with object type '
// '{found_type}. all values should be an instance '
// 'of a type which inherits from Axis.'
// ).format(
// found_type=type(x),
// ))
// if duplicates(axes):
// raise ValueError(
// 'The axes labels of a tensor cannot contain duplicates. Found: {}'
// .format(str(duplicates(axes)))
// )
// self._axes = tuple(axes)
// self.uuid = uuid.uuid4()
// @property
// def full_lengths(self):
// """
// Returns all information about the lengths of the axis objects
// in this Axes in the form of a nested tuple. An element of the
// outer tuple that is itself a tuple contains the restored lengths
// of axes that have been flattened in this Axis object.
// Returns:
// tuple: A nested tuple with the axis lengths.
// """
// return tuple(x.axes.full_lengths if x.is_flattened
// else x.length for x in self)
// @property
// def names(self):
// """
// Returns:
// tuple: The names of the outer axes.
// """
// return tuple(x.name for x in self)
std::vector<Axis> Axes::convert(const Axes& ax)
{
std::vector<Axis> rc;
for (const Axis& axis : ax.axes)
{
rc.push_back(axis);
}
return rc;
}
std::vector<Axis> Axes::convert(const std::vector<Axes>& list)
{
std::vector<Axis> rc;
for (const Axes& ax : list)
{
if (ax.axes.size() == 1)
{
rc.push_back(ax.axes[0]);
}
else
{
std::vector<Axis> tmp = convert(ax);
Axes t1(tmp);
auto x = t1.flatten();
rc.push_back(x);
}
}
return rc;
}
std::vector<size_t> Axes::lengths() const
{
// return tuple(x.length for x in self)
std::vector<size_t> rc;
for (auto a : axes)
{
rc.push_back(a.length());
}
return rc;
}
// def batch_axes(self):
// """
// Returns:
// The tensor's batch Axis wrapped in an Axes object if there is one
// on this tensor, otherwise returns None
// """
// batch_axis = self.batch_axis()
// if batch_axis:
// return Axes([batch_axis])
// else:
// return None
// def batch_axis(self):
// """
// Returns:
// The tensor's batch Axis or None if there isn't one.
// """
// for axis in self:
// if axis.is_batch:
// return axis
// def channel_axis(self):
// """
// Returns:
// The tensor's batch Axis or None if there isn't one.
// """
// for axis in self:
// if axis.is_channel:
// return axis
// def spatial_axes(self):
// """
// Returns:
// The Axes subset that are not batch, recurrent, or channel axes.
// """
// return self.feature_axes() - self.channel_axis()
// def sample_axes(self):
// """
// Returns:
// The Axes subset that are not batch axes.
// """
// return Axes(axis for axis in self if not axis.is_batch)
// def feature_axes(self):
// """
// Returns:
// The Axes subset that are not batch or recurrent axes.
// """
// return Axes(axis for axis in self if not axis.is_batch and not axis.is_recurrent)
// def recurrent_axis(self):
// """
// Returns:
// The tensor's recurrent Axis or None if there isn't one.
// """
// for axis in self:
// if axis.is_recurrent:
// return axis
Axis Axes::flatten(bool force) const
{
Axis rc;
if (!force && axes.size() == 1)
{
rc = axes[0];
}
else
{
rc = FlattenedAxis(axes);
}
return rc;
}
// def set_shape(self, shape):
// """
// Set shape of Axes
// Args:
// shape: tuple or list of shapes, must be the same length as the axes
// """
// if len(shape) != len(self._axes):
// raise ValueError("shape's length %s must be equal to axes' length"
// "%s" % (len(shape), len(self)))
// for axis, length in zip(self._axes, shape):
// axis.length = length
// def find_by_name(self, name):
// return Axes(axis for axis in self if axis.name == name)
// def __iter__(self):
// return self._axes.__iter__()
// def __len__(self):
// return len(self._axes)
// def __getitem__(self, item):
// if isinstance(item, slice):
// return Axes(self._axes.__getitem__(item))
// else:
// return self._axes.__getitem__(item)
// def __getslice__(self, i, j):
// return self.__getitem__(slice(i, j))
Axes Axes::operator+(const Axes& other)
{
// other = make_axes(other)
// common_axes = self & other
Axes common_axes = *this & other;
if (common_axes.size() != 0)
{
std::stringstream ss;
ss << "Trying to concatenate " << *this << " with " << other;
ss << ", but they have common axes " << common_axes << ", which is not allowed.";
throw std::invalid_argument(ss.str());
}
std::vector<Axis> rc = axes;
rc.insert(rc.end(), other.axes.begin(), other.axes.end());
return Axes(rc);
}
Axes Axes::operator-(const Axes& other)
{
std::vector<Axis> axis_list;
for (const Axis& axis : axes)
{
if (!contains(other.axes, axis))
{
axis_list.push_back(axis);
}
}
return Axes(axis_list);
}
Axes Axes::operator|(const Axes& other)
{
std::vector<Axis> axis_list = axes;
for (const Axis& axis : other.axes)
{
if (!contains(axes, axis))
{
axis_list.push_back(axis);
}
}
return Axes(axis_list);
}
Axes Axes::operator&(const Axes& other)
{
std::vector<Axis> axis_list;
for (const Axis& axis : axes)
{
if (contains(other.axes, axis))
{
axis_list.push_back(axis);
}
}
return Axes(axis_list);
}
bool Axes::operator==(const Axes& other) const
{
bool rc = axes.size() == other.size();
if (rc)
{
for (size_t i = 0; i < axes.size(); i++)
{
if (axes[i] != other.axes[i])
{
rc = false;
break;
}
}
}
return rc;
}
bool Axes::operator!=(const Axes& other) const
{
return !(*this == other);
}
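// A sketch of the set-style operators above (our own example; note that membership is by
// axis name, following Axis::operator==):
//
//     Axes a = make_axes({make_axis(2, "N"), make_axis(3, "C")});
//     Axes b = make_axes({make_axis(3, "C"), make_axis(5, "H")});
//     a & b;   // Axes(Axis(C: 3))      intersection
//     a | b;   // Axes(N, C, H)         union, keeping left-to-right order
//     a - b;   // Axes(Axis(N: 2))      difference
//     a + b;   // throws std::invalid_argument: the common axis C is not allowed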
// def __nonzero__(self):
// """ Axes considered nonzero if axes are nonzero. """
// return bool(self._axes)
// def __hash__(self):
// return hash(self._axes)
bool Axes::is_sub_set(const Axes& other) const
{
bool rc = true;
for (const Axis& axis : other.axes)
{
if (!contains(this->axes, axis))
{
rc = false;
break;
}
}
return rc;
}
bool Axes::is_super_set(const Axes& other) const
{
bool rc = true;
for (const Axis& axis : axes)
{
if (!contains(other.axes, axis))
{
rc = false;
break;
}
}
return rc;
}
bool Axes::is_equal_set(const Axes& other) const
{
bool rc = axes.size() == other.axes.size();
for (const Axis& axis : axes)
{
if (!contains(other.axes, axis))
{
rc = false;
break;
}
}
return rc;
}
bool Axes::is_not_equal_set(const Axes& other) const
{
return !is_equal_set(other);
}
bool Axes::operator<(const Axes& other) const
{
bool rc = false;
if (this->axes.size() == other.axes.size() && this->axes.size() > 0)
{
rc = this->axes[0] < other.axes[0];
}
else
{
rc = this->axes.size() < other.axes.size();
}
return rc;
}
// @property
// def T(self):
// return Axes(axis.T for axis in self)
// def index(self, axis):
// """
// Returns the index of an axis
// Arguments:
// axis: The axis to search for.
// Returns:
// The index.
// """
// return self._axes.index(axis)
// @staticmethod
// @with_args_as_axes
// def assert_valid_broadcast(axes, new_axes):
// """
// Checks whether axes can be broadcasted to new_axes. We require
// that the components of axes be laid out in the same order in new_axes.
// Axes:
// axes: The original axes.
// new_axes: The broadcasted axes.
// Returns:
// True if axes can be broadcasted to new_axes, False otherwise.
// """
// removed_axes = axes - new_axes
// if removed_axes:
// raise ValueError(("The new_axes of a broadcast operation must "
// "include all of the axes from the origional set "
// "of axes. \n"
// " original axes: {axes}\n"
// " new axes: {new_axes}\n"
// " missing axes: {removed_axes}").format(
// axes=axes,
// new_axes=new_axes,
// removed_axes=removed_axes,
// ))
// @staticmethod
// @with_args_as_axes
// def is_valid_flatten_or_unflatten(src_axes, dst_axes):
// """
// Checks whether we can flatten OR unflatten from src_axes to dst_axes.
// The requirements are that the components of axes should all be
// present in new_axes and that they should be laid out in the same
// order. This check is symmetric.
// """
// # inflate
// src_axes = Axes.as_flattened_list(src_axes)
// dst_axes = Axes.as_flattened_list(dst_axes)
// # check equal number of Axis
// if len(src_axes) != len(dst_axes):
// return False
// # check all Axis are equal
// equal = [src == dst for src, dst in zip(src_axes, dst_axes)]
// return all(equal)
// @staticmethod
// @with_args_as_axes
// def assert_valid_flatten(unflattend_axes, flattened_axes):
// """
// Checks whether axes can safely be flattened to produce new_axes.
// The requirements are that the components of axes should all be
// present in new_axes and that they should be laid out in the same
// order.
// Arguments:
// unflattend_axes: The original axes.
// flattened_axes: The flattened axes.
// Returns:
// True if axes can be safely flattened to new_axes, False otherwise.
// """
// if not Axes.is_valid_flatten_or_unflatten(unflattend_axes, flattened_axes):
// raise ValueError("Trying to flatten:\n%s\nto:\n%s.\n"
// "But they are of different lengths, or the axes"
// "layouts are different"
// % (unflattend_axes, flattened_axes))
// @staticmethod
// @with_args_as_axes
// def assert_valid_unflatten(flattened_axes, unflattend_axes):
// """
// Checks whether axes can safely be unflattened to produce new_axes.
// The requirements are that the components of axes should all be
// present in new_axes and that they should be laid out in the same
// order.
// Arguments:
// flattened_axes: The original axes.
// unflattend_axes: The unflattened axes.
// Returns:
// True if axes can be safely unflattened to new_axes, False otherwise.
// """
// if not Axes.is_valid_flatten_or_unflatten(flattened_axes, unflattend_axes):
// raise ValueError("Trying to unflatten:\n%s\nto:\n%s.\n"
// "But they are of different lengths, or the axes"
// "layouts are different"
// % (unflattend_axes, flattened_axes))
// @property
// def size(self):
// """
// TODO: delete this method, the size should come from the tensor
// """
// return int(np.prod(self.lengths))
std::ostream& ngraph::operator<<(std::ostream& out, const Axes& axes)
{
out << "Axes(";
out << join(axes.axes, ", ");
out << ")";
return out;
}
// def __repr__(self):
// return 'Axes({})'.format(
// ', '.join(map(repr, self))
// )
// def __str__(self):
// return ', '.join(map(str, self))
// @staticmethod
// def as_nested_list(axes):
// """
// Converts Axes to a list of axes with flattened axes expressed as nested lists
// Returns:
// Nested list of Axis objects
// """
// if isinstance(axes, (Axes, list)):
// return [Axes.as_nested_list(a) for a in axes]
// elif isinstance(axes, FlattenedAxis):
// return [Axes.as_nested_list(a) for a in axes.axes]
// elif isinstance(axes, Axis):
// return axes
// @staticmethod
// def as_flattened_list(axes):
// """
// Converts Axes to a list of axes with flattened axes expanded recursively.
// Returns:
// List of Axis objects
// """
// axes_list = [list(axis.axes) if axis.is_flattened else [axis]
// for axis in axes]
// axes = list(itertools.chain.from_iterable(axes_list))
// # inflate recursively
// if any([axis.is_flattened for axis in axes]):
// return Axes.as_flattened_list(axes)
// else:
// return axes
//================================================================================================
// DuplicateAxisNames
//================================================================================================
// class DuplicateAxisNames(ValueError):
// def __init__(self, message, duplicate_axis_names):
// super(DuplicateAxisNames, self).__init__(message)
// self.duplicate_axis_names = duplicate_axis_names
//================================================================================================
// IncompatibleAxesError
//================================================================================================
// class IncompatibleAxesError(ValueError):
// pass
//================================================================================================
// UnmatchedAxesError
//================================================================================================
// class UnmatchedAxesError(IncompatibleAxesError):
// pass
//================================================================================================
// AxesMap
// AxesMap provides a way to define an axis name mapping: {Axis.name: Axis.name} and
// then apply this mapping to an Axes to get a new Axes out.
//
// Right now AxesMap is implemented as immutable because I didn't want to deal with
// enforcing _assert_valid_axes_map on every method which mutates a dict, and I didn't
// need a mutable data structure anyway. Feel free to make it mutable and add in
// invariant enforcement.
//================================================================================================
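// A small usage sketch of AxesMap (our own example):
//
//     AxesMap m{{"C", "channels"}, {"H", "height"}};
//     Axes in = make_axes({make_axis(3, "C"), make_axis(5, "N")});
//     Axes out = m.map_axes(in);   // -> Axes(Axis(channels: 3), Axis(N: 5))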
AxesMap::AxesMap(const std::pair<std::string, std::string>& p)
{
this->insert(p);
}
AxesMap::AxesMap(std::initializer_list<std::pair<std::string, std::string>> list)
{
this->insert(list.begin(), list.end());
assert_valid_axes_map();
}
// def __init__(self, *args, **kwargs):
// def replace_axis_with_name(x):
// if isinstance(x, Axis):
// return x.name
// return x
// # strip axis objects into just names
// super(AxesMap, self).__init__({
// replace_axis_with_name(k): replace_axis_with_name(v)
// for k, v in dict(*args, **kwargs).items()
// })
// self._assert_valid_axes_map()
Axes AxesMap::map_axes(const Axes& ax) const
{
std::vector<Axis> mapped_list;
for (const Axis& axis : ax)
{
mapped_list.push_back(map_axis(axis));
}
return make_axes(mapped_list);
}
Axis AxesMap::map_axis(const Axis& old_axis) const
{
Axis rc = old_axis;
if (contains_key(*this, old_axis.name))
{
rc = make_axis(old_axis.length(), this->at(old_axis.name));
}
return rc;
}
std::map<std::string, std::set<std::string>> AxesMap::duplicate_axis_names()
{
std::map<std::string, std::set<std::string>> counts;
for (auto p : *this)
{
counts[p.second].insert(p.first);
}
std::map<std::string, std::set<std::string>> rc;
for (auto p : counts)
{
if (p.second.size() > 1)
{
rc.insert(p);
}
}
return rc;
}
void AxesMap::assert_valid_axes_map()
{
auto duplicate_names = duplicate_axis_names();
// if there are duplicate_axis_names throw an exception
if (duplicate_names.size() > 0)
{
std::stringstream ss;
ss << "AxesMap can not have duplicate names, but found:";
for (auto p : duplicate_names)
{
ss << "\n " << p.first << " maps to " << join(p.second, ", ");
}
throw std::invalid_argument(ss.str());
}
}
// def invert(self):
// return {v: k for k, v in self.items()}
// def _reduce_nested(elem, agg, func):
// """
// Reduces a nested sequence by applying a function to each
// of its elements and returns an aggregation.
// Arguments:
// elem: The object to be reduced, either a sequence
// or a singleton.
// agg: A variable holding information collected
// as the sequence is collapsed.
// func: A function to augment the aggregate by processing
// a singleton. Should have the form func(agg, elem) -> agg
// Returns:
// agg: The final aggregate returned by the function.
// """
// if isinstance(elem, collections.Iterable):
// for sub in elem:
// agg = _reduce_nested(sub, agg, func)
// return agg
// else:
// return func(agg, elem)
//================================================================================================
// FlattenedAxis
// A FlattenedAxis has a length which is the product of the lengths of all
// Axis objects in the axes. The original Axes object is stored so that we can later
// unflatten this Axis back to its original component Axis objects.
//
// Note: since we allow Axis objects to have duplicated names globally, NameableValue
// is not used here.
//================================================================================================
FlattenedAxis::FlattenedAxis(const std::vector<Axis>& list, const std::string& new_name)
{
// get length
Axes ax(list);
// if len(axes) == 1 and axes[0].is_flattened:
// pass
// length = reduce(operator.mul, axes.lengths, 1)
auto lengths = ax.lengths();
__length = ngraph::reduce(lengths.begin(), lengths.end(), ngraph::mul<size_t>);
// # set name
// name = '%s_%s' % (type(self).__name__, type(self).__name_counter)
// type(self).__name_counter += 1
name = new_name;
if (name.size() == 0)
{
std::stringstream ss;
ss << "Axis_" << __name_counter++;
name = ss.str();
}
// # parent constructor
// super(FlattenedAxis, self).__init__(length=length, name=name, **kwargs)
// self._axes = axes
axes = list;
}
std::ostream& ngraph::operator<<(std::ostream& out, const FlattenedAxis& obj)
{
out << obj.to_string();
return out;
}
std::string FlattenedAxis::to_string() const
{
std::stringstream ss;
ss << "FlattenedAxis(" << join(axes, ", ") << ")";
return ss.str();
}
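// A small sketch of flattening (our own example): the component lengths are multiplied, so
//
//     Axes hw = make_axes({make_axis(2, "H"), make_axis(3, "W")});
//     Axis flat = hw.flatten(true);
//     flat.length();   // -> 6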
// def _make_stride(inner_size, axis, fsz):
// """
// Generates a nested tuple that provides the striding information
// for an occurrence of axis. If the axis is a FlattenedAxis, the
// stride will be a tuple containing the strides of each collapsed
// axis. Otherwise, the stride will be an integer.
// Arguments:
// inner_size: The total size of all dimensions smaller than this
// axis, i.e. all axes to the right of this one when they are
// laid out in c-contiguous order.
// axis: The axis for which we are generating a stride.
// fsz: A nested tuple supplying the sizes of each dimension collapsed
// into the axis. The size may be larger than the length of the axis.
// Returns:
// inner_size: The total size of this axis and all smaller dimensions.
// stride: The stride given to the axis.
// """
// if axis.is_flattened:
// return _make_strides(inner_size, axis.axes, fsz)
// else:
// stride = inner_size
// inner_size *= fsz
// return inner_size, stride
// def _make_strides(inner_size, axes, full_sizes):
// """
// Generates a tuple of strides for a set of axes. See _make_stride
// for a description of the stride given to each axis.
// Arguments:
// inner_size: The total size of all dimensions smaller than
// the axes.
// axes: The axes for which we are generating strides.
// full_sizes: The size of each axis.
// Returns:
// inner_size: The total size of these axes and all smaller dimensions.
// strides: The strides generated for the axes.
// """
// full_strides = []
// for axis, fsz in reversed(list(zip(axes, full_sizes))):
// inner_size, stride = _make_stride(inner_size, axis, fsz)
// full_strides.append(stride)
// return inner_size, tuple(reversed(full_strides))
//================================================================================================
// TensorDescription
// Description of a tensor that will be allocated in hardware.
//
// Names the tensor's dimensions with axes and holds pointers to the
// buffer allocated by the analysis and the backend tensor value
// (e.g. a cpu or gpu tensor).
//
// Arguments:
// axes: Axes of the tensor.
// base: If a view, the viewed tensor's description.
// dtype: The type of the tensor.
// full_strides: The strides of each axis.
// full_sizes: The allocated size of each axis (may be larger than the axis).
// offset: An offset into the viewed tensor.
// next_tensor_description: In a reshape, tensor description of reshaped tensor.
// is_persistent: The tensor should be persistent, i.e. survive from computation to
// computation.
// is_input: The device tensor can be written from the host.
// **kwargs: Additional args for related classes.
//================================================================================================
TensorDescription::TensorDescription(op_ptr _op,
const Axes& _axes,
tensor_description_ptr base,
// layout,
ElementType et,
ngraph::tensor_stride _full_strides,
ngraph::tensor_size _full_sizes,
size_t _offset,
TensorDescription* next_tensor_description,
const std::string& _name,
bool is_persistent,
bool is_input,
bool is_placeholder)
: NameableValue(_name)
, op{_op}
, axes{_axes}
, __is_persistent{is_persistent}
, __is_input{is_input}
, __is_placeholder{is_placeholder}
, __base{base}
, dtype{et}
, full_sizes{_full_sizes}
, full_strides{_full_strides}
, offset{_offset}
{
// super(TensorDescription, self).__init__(**kwargs)
// # TODO: get the default type from the backend. May not always be numpy.
// # TODO: support flattening, unflattening, other complex reshapes
// axes = Axes(axes)
// self.axes = axes
// self.__layout = layout
// self.__value = None
// self.__buffer = None
// self.__register = None
// self.__base = base
// self.dtype = default_dtype(dtype)
// self.offset = offset
// self.ndim = len(self.axes)
// self.full_sizes = tuple(full_sizes) if full_sizes is not None \
// else self.axes.full_lengths
// self.next_tensor_description = next_tensor_description
// self.__is_persistent = is_persistent
// self.__is_input = is_input
// self.__is_placeholder = is_placeholder
// self.op = op
// if not isinstance(self.name, str):
// raise ValueError()
// for axis in axes:
// if axis.length is None:
// raise ValueError((
// 'axes used in the constructor of TensorDescription must '
// 'always have non-None length. Axis {axis} has length '
// 'None.'
// ).format(axis=axis))
// if full_strides is None:
// _, full_strides = _make_strides(
// self.dtype.itemsize,
// self.axes,
// self.full_sizes
// )
// self.full_strides = full_strides
// else:
// self.full_strides = tuple(full_strides)
// assert len(self.full_sizes) == self.ndim, \
// "Sizes must have same number of dimensions as axes"
// assert len(self.full_strides) == self.ndim, \
// "Strides must have same number of dimensions as axes"
}
// def __repr__(self):
// return self.base.name
ElementType TensorDescription::element_type() const
{
return dtype;
}
bool TensorDescription::is_persistent() const
{
return __is_persistent;
}
bool TensorDescription::is_input() const
{
return __is_input;
}
bool TensorDescription::is_placeholder() const
{
return __is_placeholder;
}
// @property
// def parameter_key(self):
// """
// Returns: A tuple that can be used to tell if two views of a tensor are equivalent.
// """
// return (self.shape, self.dtype, self.offset, self.strides, self.layout)
// @property
axes_key_t TensorDescription::axes_key() const
{
std::hash<Axes> axes_hash;
std::hash<size_t> offset_hash;
// std::hash<decltype(strides)> strides_hash;
// std::hash<decltype(layout)> layout_hash;
std::vector<size_t> hash_list;
hash_list.push_back(axes_hash(axes));
// hash_list.push_back(hash_combine(shape));
hash_list.push_back(dtype.hash());
hash_list.push_back(offset_hash(offset));
// TODO: add strides and layout to hash
// def axes_key(self):
// return (self.axes, self.shape, self.dtype, self.offset, self.strides, self.layout)
return hash_combine(hash_list);
}
// def flatten(self, new_axes):
// """
// Flattens a tensor description to give it the Axes in new_axes.
// See Axes.assert_valid_flatten for a description of permitted values of new_axes.
// Arguments:
// new_axes: The Axes of the flattened tensor description.
// Returns:
// The reshaped tensor description.
// """
// new_axes = Axes(new_axes)
// Axes.assert_valid_flatten(self.axes, new_axes)
// new_strides = []
// new_sizes = []
// idx = 0
// for new_axis in new_axes:
// if new_axis == self.axes[idx]:
// new_stride = self.full_strides[idx]
// new_size = self.full_sizes[idx]
// idx += 1
// else:
// l = len(new_axis.axes)
// new_stride = self.full_strides[idx:idx + l]
// new_size = self.full_sizes[idx:idx + l]
// idx += l
// new_strides.append(new_stride)
// new_sizes.append(new_size)
// return TensorDescription(
// new_axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=new_strides,
// full_sizes=new_sizes,
// offset=self.offset,
// next_tensor_description=self,
// name=self.name + 'rFlatten',
// )
// def unflatten(self, new_axes):
// """
// Unflattens a tensor description to give it the Axes in new_axes.
// See Axes.assert_valid_unflatten for a description of the permitted values of
// new_axes
// Arguments:
// new_axes: The Axes of the unflattened TensorDescription.
// Returns:
// The unflattened tensor description.
// """
// def find_axis_stride_and_length(axis):
// """
// Find the stride and length for an axis.
// Start at the current tensor description and then work back
// through reshapings of it looking for a mention of the axis
// that can be used to determine the storage stride and offset.
// Args:
// axis: The axis.
// Returns:
// stride, length of axis
// """
// td = self
// while td is not None:
// for idx, a in enumerate(td.axes):
// # Try to find a match for axis in this td
// full_strides = td.full_strides[idx]
// full_sizes = td.full_sizes[idx]
// if a == axis:
// return full_strides, full_sizes
// if a.is_flattened:
// # Can be embedded in a flattened axis description
// if not isinstance(full_strides, tuple):
// # An axis cast can lose striding info, so need to
// # recreate it from the axis lengths. Being flattened
// # implies C-contiguous
// stride = full_strides
// full_strides = []
// full_sizes = []
// for s in reversed(a.axes):
// full_sizes.insert(0, s.length)
// full_strides.insert(0, stride)
// stride = stride * s.length
// # Now search for axis in the flattened axis
// for sub_idx, b in enumerate(a.axes):
// if b == axis:
// return full_strides[sub_idx], full_sizes[sub_idx]
// # Move on to the next tensor description in the reshaping chain
// td = td.next_tensor_description
// # Sometimes we just don't have enough information.
// raise ValueError()
// new_axes = Axes(new_axes)
// Axes.assert_valid_unflatten(self.axes, new_axes)
// new_strides = []
// new_sizes = []
// for new_axis in new_axes:
// stride, size = find_axis_stride_and_length(new_axis)
// new_strides.append(stride)
// new_sizes.append(size)
// return TensorDescription(
// new_axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=new_strides,
// full_sizes=new_sizes,
// offset=self.offset,
// next_tensor_description=self,
// name=self.name + 'rUnflatten',
// )
// def transpose(self):
// """
// Reverses the axes of the tensor description.
// Returns:
// A tensor description with the axes reversed.
// """
// new_axes = reversed(self.axes)
// full_sizes = reversed(self.full_sizes)
// full_strides = reversed(self.full_strides)
// return TensorDescription(
// Axes(new_axes),
// base=self.base,
// dtype=self.dtype,
// full_strides=tuple(full_strides),
// full_sizes=tuple(full_sizes),
// offset=self.offset,
// next_tensor_description=self,
// name=self.name + 'rTranspose',
// )
// def clone(self):
// """
// Creates a copy of this tensor description
// Returns:
// A copy of this tensor description
// """
// return TensorDescription(
// self.axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=self.full_strides,
// full_sizes=self.full_sizes,
// offset=self.offset,
// next_tensor_description=self.next_tensor_description,
// name=self.name + 'cView',
// )
// TensorDescription TensorDescription::broadcast(const Axes& new_axes)
// {
// Axes::assert_valid_broadcast(axes, new_axes);
// return reorder_and_broadcast(new_axes);
// }
// Axes.assert_valid_broadcast(self.axes, new_axes)
// return self.reorder_and_broadcast(new_axes)
// def reorder(self, new_axes):
// """
// Shuffles axes of a tensor to give it a new shape. The axes of
// this tensor description and new_axes must have the same elements.
// Arguments:
// new_axes: The axes of the reordered tensor.
// Returns:
// TensorDescription: The reordered tensor description.
// """
// if not self.axes.is_equal_set(new_axes):
// raise ValueError((
// "Reorder can't change which axes are available, only the "
// "order. {} and {} are different sets, not just order."
// ).format(self, new_axes))
// return self.reorder_and_broadcast(new_axes)
// def reorder_and_broadcast(self, new_axes):
// """
// Adds or shuffles axes to give a tensor description a new shape.
// This function is used to implement broadcast and reorder.
// Arguments:
// new_axes: The axes of the broadcasted or reordered tensor.
// Returns:
// TensorDescription: A description of the tensor after the
// transformation.
// """
// def zero_in_shape(tup):
// zero_in_shape()
// {
// // if isinstance(tup, collections.Iterable):
// // return tuple(
// // zero_in_shape(t) for t in tup
// // )
// // else:
// // return 0
// }
// TensorDescription TensorDescription::reorder_and_broadcast(const Axes& _new_axes)
// {
// // new_axes = Axes(new_axes)
// auto new_axes = Axes(_new_axes);
// // new_strides = []
// std::vector<size_t> new_strides;
// // new_sizes = []
// std::vector<size_t> new_sizes;
// // for axis in new_axes:
// for (const Axis& axis : new_axes)
// {
// // if axis in self.axes:
// if (contains(axes, axis))
// {
// // idx = self.axes.index(axis)
// auto idx = axes.index(axis);
// // new_strides.append(self.full_strides[idx])
// new_strides.push_back(full_strides[idx]);
// // new_sizes.append(self.full_sizes[idx])
// new_sizes.push_back(full_sizes[idx]);
// }
// // elif axis.is_flattened:
// else if (axis.is_flattened())
// {
// // lengths = axis.axes.full_lengths
// auto lengths = axis.axes().full_lengths();
// // new_strides.append(zero_in_shape(lengths))
// new_strides.push_back(zero_in_shape(lengths));
// // new_sizes.append(lengths)
// new_sizes.push_back(lengths);
// }
// // else:
// else
// {
// // new_strides.append(0)
// new_strides.push_back(0);
// // new_sizes.append(axis.length)
// new_sizes.push_back(axis.length());
// }
// }
// return TensorDescription(
// nullptr,
// new_axes,
// base,
// dtype,
// new_strides,
// new_sizes,
// offset,
// this,
// name() + "rReorderBroadcast"
// );
// }
// def cast(self, new_axes):
// """
// Return a tensor desciption for a view of the tensor.
// Arguments:
// new_axes: The axes for the view.
// Returns:
// The tensor description.
// """
// full_strides = self.full_strides
// full_sizes = self.full_sizes
// if self.ndim == 0:
// full_strides = (0,) * len(new_axes)
// full_sizes = new_axes.full_lengths
// return TensorDescription(
// new_axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=full_strides,
// full_sizes=full_sizes,
// offset=self.offset,
// next_tensor_description=self,
// name=self.name + 'rCast',
// )
// def slice(self, slices, new_axes):
// """
// Return a tensor description for a slice view of this tensor.
// Arguments:
// slices: The slices to take from the tensor, each of which is
// either an integer or a python slice. If the input has too few
// axes for the tensor, we assume that the entire axis should be
// taken for dimensions towards the end of the tensor.
// new_axes: the axes to use as labels for the sliced tensor.
// Returns:
// The tensor description for the slice.
// """
// slices = list(slices)
// while len(slices) < self.ndim:
// slices.append(slice(None))
// offset = self.offset
// full_strides = []
// full_sizes = []
// new_index = 0
// # check new_axes for the correct length
// num_dimensions_out = len([s for s in slices if isinstance(s, slice)])
// if len(new_axes) != num_dimensions_out:
// raise ValueError((
// 'in a slice operation, the number of axes passed in to '
// 'new_axes ({num_new_axes}) must be the same as the number of '
// 'slice objects in slices ({num_slices}).'
// ).format(
// num_new_axes=len(new_axes),
// num_slices=num_dimensions_out,
// ))
// for s, axis, stride, size in zip(slices, self.axes, self.strides, self.sizes):
// if isinstance(s, slice):
// # only increment new_axis when the input slice is a slice and
// # not a integer
// new_axis = new_axes[new_index]
// new_index += 1
// # ensure slice is of the kind we support
// _validate_slice(s)
// # ensure new_axis has the correct length
// new_axis.length = _sliced_length(s, axis.length)
// start, stop, step = s.indices(axis.length)
// full_strides.append(stride * step)
// full_sizes.append(size)
// idx = start
// else:
// # this is a simple integer slice, ex: y = x[1]
// idx = s
// # TODO: write a test that fails if abs() is removed
// offset += idx * abs(stride)
// return TensorDescription(
// new_axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=tuple(full_strides),
// full_sizes=tuple(full_sizes),
// offset=offset,
// next_tensor_description=self,
// name=self.name + "rSlice",
// )
std::vector<size_t> TensorDescription::shape() const
{
return axes.lengths();
}
// @property
// def strides(self):
// """The strides of the tensor."""
// return reduce_strides(self.full_strides)
// @property
// def sizes(self):
// """The allocated sizes for each axis."""
// return tuple(_reduce_nested(_, 1, operator.mul)
// for _ in self.full_sizes)
size_t TensorDescription::tensor_size() const
{
throw std::runtime_error("unimplemented");
}
// result = self.dtype.itemsize
// for s in self.sizes:
// result = result * s
// return result
// @property
// def c_contiguous(self):
// """
// Returns:
// True if the tensor's strides are row-major contiguous.
// """
// s = self.dtype.itemsize
// cstrides = []
// for _ in reversed(self.shape):
// cstrides.insert(0, s)
// s = s * _
// return tuple(cstrides) == self.strides
// @property
// def broadcast_contiguous(self):
// """
// Returns:
// True if tensor's strides are contiguous or broadcasted
// """
// if self.shape == ():
// return True
// broadcast_axes = np.where(np.equal(self.strides, 0))[0]
// aug_shape = list(self.shape)
// for bcast_axis in broadcast_axes:
// aug_shape[bcast_axis] = 1
// s = self.dtype.itemsize
// cstrides = []
// for _ in reversed(aug_shape):
// cstrides.insert(0, s)
// s = s * _
// for bcast_axis in broadcast_axes:
// cstrides[bcast_axis] = 0
// return tuple(cstrides) == self.strides
tensor_description_ptr TensorDescription::base() const
{
return __base;
}
// @property
// def layout(self):
// """The layout of the underlying storage."""
// return self.__layout
// @layout.setter
// def layout(self, value):
// """
// Sets the backend-specific memory layout to be used by the tensor.
// Arguments:
// value: the layout to use
// Returns:
// """
// self.__layout = value
// @property
// def register(self):
// return self.base.__register
// @register.setter
// def register(self, value):
// self.base.__register = value
// def is_base(self):
// """This tensor provides its own storage."""
// return self.__base is None
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include <initializer_list>
#include <limits>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "element_type.hpp"
#include "names.hpp"
#include "strides.hpp"
#include "util.hpp"
#include "uuid.hpp"
#include "uuid.hpp"
namespace ngraph
{
class Axes;
class Axis;
class FlattenedAxis;
class TensorDescription;
class Op;
using op_ptr = std::shared_ptr<Op>;
using tensor_description_ptr = std::shared_ptr<TensorDescription>;
using axes_key_t = size_t;
class slice
{
public:
slice(int64_t start = -1, int64_t stop = -1, int64_t step = 1);
size_t sliced_length(size_t length) const;
private:
size_t m_start;
size_t m_stop;
int64_t m_step;
};
//-----------------------------------------------------------------------------------------------
// default_dtype
//-----------------------------------------------------------------------------------------------
// def default_dtype(dtype=None):
// if dtype is None:
// dtype = np.dtype(np.float32)
// elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype):
// try:
// dtype = np.dtype(dtype)
// except TypeError:
// raise TypeError("Could not cast {} to np.dtype".format(dtype))
// return dtype
//-----------------------------------------------------------------------------------------------
// default_int_dtype
//-----------------------------------------------------------------------------------------------
// def default_int_dtype(dtype=None):
// if dtype is None:
// dtype = np.dtype(np.int32)
// elif not isinstance(dtype, Flex) and not isinstance(dtype, np.dtype):
// try:
// dtype = np.dtype(dtype)
// except TypeError:
// raise TypeError("Could not cast {} to np.dtype".format(dtype))
// return dtype
//================================================================================================
// make_axis
// Returns a new Axis.
//
// Args:
// length (int, optional): Length of the axis.
// name (String, optional): Name of the axis.
// batch (bool, optional): This is a batch axis. Defaults to False.
// recurrent (bool, optional): This is a recurrent axis. Defaults to False.
// docstring (String, optional): A docstring for the axis.
//
// Returns:
// Axis: A new Axis.
//================================================================================================
Axis make_axis(size_t length,
const std::string& name = "",
bool batch = false,
bool recurrent = false);
//================================================================================================
// make_axes
// Makes an Axes object.
//
// Args:
// axes: A list of Axis.
//
// Returns:
// Axes: An Axes.
//================================================================================================
Axes make_axes(const std::vector<Axis>&);
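//-------------------------------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the API): given the Axis/Axes interfaces declared
// below, a caller might build a small set of axes like this:
//
//     Axis H = make_axis(480, "H");
//     Axis W = make_axis(640, "W");
//     Axis N = make_axis(32, "N", true); // batch axis
//     Axes image = make_axes({H, W, N});
//     // image.lengths() == {480, 640, 32}
//-------------------------------------------------------------------------------------------------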
//================================================================================================
// Axis
// An Axis labels a dimension of a tensor. The op-graph uses
// the identity of Axis objects to pair and specify dimensions in
// symbolic expressions. This system has several advantages over
// using the length and position of the axis as in other frameworks:
//
// 1) Convenience. The dimensions of tensors, which may be nested
// deep in a computation graph, can be specified without having to
// calculate their lengths.
//
// 2) Safety. Axis labels are analogous to types in general-purpose
// programming languages, allowing objects to interact only when
// they are permitted to do so in advance. In symbolic computation,
// this prevents interference between axes that happen to have the
// same lengths but are logically distinct, e.g. if the number of
// training examples and the number of input features are both 50.
//
// TODO: Please add to the list...
//
// Arguments:
// length: The length of the axis.
// batch: Whether the axis is a batch axis.
// recurrent: Whether the axis is a recurrent axis.
//================================================================================================
class Axis
{
public:
Axis& operator+(const Axis& rhs);
Axis& operator-(const Axis& rhs);
Axis();
Axis(size_t length, const std::string& new_name);
virtual ~Axis() {}
void named(const std::string& new_name);
//!-----------------------------------------------------------------------------------
//! is_flattened
//! Returns:
//! true if this is a flattened axis.
//!-----------------------------------------------------------------------------------
bool is_flattened() const;
//!-----------------------------------------------------------------------------------
//! is_batch
//! Tests if an axis is a batch axis.
//!
//! Returns:
//! bool: True if the axis is a batch axis.
//!-----------------------------------------------------------------------------------
bool is_batch() const;
//!-----------------------------------------------------------------------------------
//! is_recurrent
//! Tests if an axis is a recurrent axis.
//!
//! Returns:
//! bool: True if the axis is a recurrent axis.
//!-----------------------------------------------------------------------------------
bool is_recurrent() const;
//!-----------------------------------------------------------------------------------
//! is_channel
//! Tests if an axis is a channel axis.
//!
//! Returns:
//! bool: True if the axis is a channel axis.
//!-----------------------------------------------------------------------------------
bool is_channel() const;
//!-----------------------------------------------------------------------------------
//! length
//! Returns:
//! The length of the axis.
//!-----------------------------------------------------------------------------------
size_t length() const;
//!-----------------------------------------------------------------------------------
//! length
//!-----------------------------------------------------------------------------------
void length(size_t new_length);
//!-----------------------------------------------------------------------------------
//! axes
//!-----------------------------------------------------------------------------------
Axes axes() const;
//!-----------------------------------------------------------------------------------
//! operator<<
//!-----------------------------------------------------------------------------------
friend std::ostream& operator<<(std::ostream&, const Axis&);
virtual std::string to_string() const;
//!-----------------------------------------------------------------------------------
//! __str__ (Python reference)
//!-----------------------------------------------------------------------------------
// def __str__(self):
// return '{name}: {length}'.format(name=self.name, length=self.length)
//!-----------------------------------------------------------------------------------
//! operator==
//!-----------------------------------------------------------------------------------
bool operator==(const Axis&) const;
//!-----------------------------------------------------------------------------------
//! operator!=
//!-----------------------------------------------------------------------------------
bool operator!=(const Axis&) const;
bool operator<(const Axis&) const;
//!-----------------------------------------------------------------------------------
//! hash
//!-----------------------------------------------------------------------------------
size_t hash() const;
std::string name;
uuid_type uuid;
size_t __length;
static size_t __name_counter;
};
//-----------------------------------------------------------------------------------------------
// _sliced_length
//-----------------------------------------------------------------------------------------------
// def _sliced_length(s, incoming_length):
// start, stop, step = s.indices(incoming_length)
// # max with 0 so we dont ever return a negative length. This
// # matches how python handles it internally. Raising an exception
// # might also be reasonable.
// if step == 1:
// return max(stop - start, 0)
// elif step == -1:
// return max(start - stop, 0)
// else:
// _validate_slice(s)
//-----------------------------------------------------------------------------------------------
// _validate_slice
//-----------------------------------------------------------------------------------------------
// def _validate_slice(s):
// if s.step not in (-1, 1, None):
// raise ValueError((
// 'SlicedAxis cant currently handle a step size other '
// 'than -1, 1 or None. Was given {step} in slice {slice}'
// ).format(
// step=s.step,
// slice=s,
// ))
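//-------------------------------------------------------------------------------------------------
// A minimal C++ sketch of the length computation above (illustrative; slice::sliced_length is the
// declared entry point). It assumes start/stop/step have already been clamped to the incoming
// length, as Python's slice.indices() does, that only steps of -1 and 1 are supported, and that
// <algorithm> and <stdexcept> are available:
//
//     int64_t sliced_length(int64_t start, int64_t stop, int64_t step)
//     {
//         if (step == 1)
//         {
//             return std::max<int64_t>(stop - start, 0); // max with 0 so we never go negative
//         }
//         if (step == -1)
//         {
//             return std::max<int64_t>(start - stop, 0);
//         }
//         throw std::invalid_argument("slice step must be -1 or 1");
//     }
//-------------------------------------------------------------------------------------------------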
//-----------------------------------------------------------------------------------------------
// slice_axis
// Slice an axis, return complete new axis
// TODO: deprecate this after the axis refactoring
//
// Arguments:
// axis: the axis to be sliced
// s: slice
//
// Returns:
// Axis instance, the new sliced axis
//-----------------------------------------------------------------------------------------------
// def slice_axis(axis, s):
Axis slice_axis(const Axis& axis, const slice& s);
//-----------------------------------------------------------------------------------------------
// duplicates
// Returns a list of Axis objects which have duplicate names in arr
//
// Arguments:
// arr: The iterable of Axis objects to check for duplicates in.
//
// Returns:
// list of Axis: duplicate Axis found in arr
//-----------------------------------------------------------------------------------------------
std::vector<std::string> duplicates(const std::vector<Axis>& ax);
//-----------------------------------------------------------------------------------------------
// with_args_as_axes
// A decorator to cast arguments to axes.
//
// Arguments:
// f: The function to be decorated.
//
// Returns:
// The decorated function.
//-----------------------------------------------------------------------------------------------
// def with_args_as_axes(f):
// @wraps(f)
// def wrapper(*args):
// """
// The decorated function. Performs the conversion
// to Axes.
// Arguments:
// *args: Arguments intended for the original function.
// Returns:
// Return value of the original function.
// """
// args = [Axes(arg) for arg in args]
// return f(*args)
// return wrapper
//================================================================================================
// Axes
// An Axes is a tuple of Axis objects used as a label for a tensor's
// dimensions.
//================================================================================================
class Axes
{
public:
std::vector<Axis> axes;
uuid_type uuid;
Axes();
Axes(const Axis&);
Axes(const std::vector<Axis>&);
Axes(const std::initializer_list<Axes>&);
//!-----------------------------------------------------------------------------------
//! full_lengths
//! Returns all information about the lengths of the axis objects
//! in this Axes in the form of a nested tuple. An element of the
//! outer tuple that is itself a tuple contains the restored lengths
//! of axes that have been flattened in this Axis object.
//!
//! Returns:
//! tuple: A nested tuple with the axis lengths.
//!-----------------------------------------------------------------------------------
std::vector<std::vector<size_t>> full_lengths() const;
//!-----------------------------------------------------------------------------------
//! names
//! Returns:
//! tuple: The names of the outer axes.
//!-----------------------------------------------------------------------------------
std::vector<std::string> names() const;
//!-----------------------------------------------------------------------------------
//! lengths
//! Returns:
//! tuple: The lengths of the outer axes.
//!-----------------------------------------------------------------------------------
std::vector<size_t> lengths() const;
//!-----------------------------------------------------------------------------------
//! batch_axes
//! Returns:
//! The tensor's batch Axis wrapped in an Axes object if there is one
//! on this tensor, otherwise returns None
//!-----------------------------------------------------------------------------------
Axes batch_axes();
//!-----------------------------------------------------------------------------------
//! batch_axis
//! Returns:
//! The tensor's batch Axis or None if there isn't one.
//!-----------------------------------------------------------------------------------
Axis batch_axis();
//!-----------------------------------------------------------------------------------
//! channel_axis
//! Returns:
//! The tensor's channel Axis or None if there isn't one.
//!-----------------------------------------------------------------------------------
Axes channel_axis();
//!-----------------------------------------------------------------------------------
//! spatial_axes
//! Returns:
//! The Axes subset that are not batch, recurrent, or channel axes.
//!-----------------------------------------------------------------------------------
Axes spatial_axes();
//!-----------------------------------------------------------------------------------
//! sample_axes
//! Returns:
//! The Axes subset that are not batch axes.
//!-----------------------------------------------------------------------------------
Axes sample_axes();
//!-----------------------------------------------------------------------------------
//! feature_axes
//! Returns:
//! The Axes subset that are not batch or recurrent axes.
//!-----------------------------------------------------------------------------------
Axes feature_axes();
//!-----------------------------------------------------------------------------------
//! recurrent_axis
//! Returns:
//! The tensor's recurrent Axis or None if there isn't one.
//!-----------------------------------------------------------------------------------
Axis recurrent_axis() const;
//!-----------------------------------------------------------------------------------
//! flatten
//! Produces flattened form of axes
//!
//! Args:
//! force: Add a FlattenedAxis even when the axis is already flat. This is needed
//! when the flatten is balanced by a later unflatten, as in dot.
//!
//! Returns:
//! A flat axis.
//!-----------------------------------------------------------------------------------
Axis flatten(bool force = false) const;
//!-----------------------------------------------------------------------------------
//! set_shape
//! Set shape of Axes
//!
//! Args:
//! shape: tuple or list of shapes, must be the same length as the axes
//!-----------------------------------------------------------------------------------
void set_shape(std::vector<size_t> shapes);
//!-----------------------------------------------------------------------------------
//! find_by_name
//!-----------------------------------------------------------------------------------
Axes find_by_name(const std::string&);
decltype(axes)::iterator begin() { return axes.begin(); }
decltype(axes)::iterator end() { return axes.end(); }
decltype(axes)::const_iterator begin() const { return axes.begin(); }
decltype(axes)::const_iterator end() const { return axes.end(); }
// def __iter__(self):
// return self._axes.__iter__()
// def __len__(self):
// return len(self._axes)
const Axis& operator[](size_t) const;
const Axis& operator[](const slice&) const;
// def __getitem__(self, item):
// if isinstance(item, slice):
// return Axes(self._axes.__getitem__(item))
// else:
// return self._axes.__getitem__(item)
// def __getslice__(self, i, j):
// return self.__getitem__(slice(i, j))
//!-----------------------------------------------------------------------------------
//! operator+
//! Returns the concatenated axes. Throws an exception when there are duplicate
//! Axis entries.
//!
//! Arguments:
//! other: the right-hand side operator axes
//!
//! Returns:
//! current axes concatenated with the other axes
//!-----------------------------------------------------------------------------------
Axes operator+(const Axes&);
//!-----------------------------------------------------------------------------------
//! operator-
//! Returns ordered set difference of axes.
//!
//! Arguments:
//! other: the right-hand side operator axes
//!
//! Returns:
//! The ordered set difference of axes
//!-----------------------------------------------------------------------------------
Axes operator-(const Axes&);
//!-----------------------------------------------------------------------------------
//! operator|
//! Returns ordered set union of axes.
//!
//! Arguments:
//! other: the right-hand side operator axes
//!
//! Returns:
//! The ordered set union of axes
//!-----------------------------------------------------------------------------------
Axes operator|(const Axes&);
//!-----------------------------------------------------------------------------------
//! operator&
//! Returns ordered set intersection of axes.
//!
//! Arguments:
//! other: the right-hand side operator axes
//!
//! Returns:
//! The ordered set intersection of axes
//!-----------------------------------------------------------------------------------
Axes operator&(const Axes&);
//!-----------------------------------------------------------------------------------
//! operator==
//! True if each ``Axis`` matches and is in the same order (list comparison)
//!
//! Arguments:
//! other: the right-hand side operator axes
//!
//! Returns:
//! bool, True if each ``Axis`` matches and is in the same order
//!
//! See Also ``is_equal_set`` if you want the comparison to ignore the Axes order
//!-----------------------------------------------------------------------------------
bool operator==(const Axes&) const;
//!-----------------------------------------------------------------------------------
//! operator!=
//! The opposite of __eq__, True if not all ``Axis`` are matching or in
//! different order (list comparison)
//!
//! Arguments:
//! other: the right-hand side operator axes
//!
//! Returns:
//! bool, True if not all ``Axis`` are matching or in different order
//!-----------------------------------------------------------------------------------
bool operator!=(const Axes&) const;
bool operator<(const Axes&) const;
//!-----------------------------------------------------------------------------------
//! axes
//! Axes considered nonzero if axes are nonzero.
//!-----------------------------------------------------------------------------------
// def __nonzero__(self):
// """ """
// return bool(self._axes)
// //!-----------------------------------------------------------------------------------
// //! hash
// //!-----------------------------------------------------------------------------------
// size_t hash() const;
//!-----------------------------------------------------------------------------------
//! is_sub_set
//! Returns true if other is subset of self, i.e. <=
//!
//! Arguments:
//! other: the right-hand side operator axes
//!
//! Returns:
//! bool, true if other is subset of self
//!-----------------------------------------------------------------------------------
bool is_sub_set(const Axes& other) const;
//!-----------------------------------------------------------------------------------
//! is_super_set
//! Returns true if other is superset of self, i.e. >=
//!
//! Arguments:
//! other: the right-hand side operator axes
//!
//! Returns:
//! bool, true if other is superset of self
//!-----------------------------------------------------------------------------------
bool is_super_set(const Axes& other) const;
//!-----------------------------------------------------------------------------------
//! is_equal_set
//! Returns true if other has the same set of Axis names as self
//!
//! Arguments:
//! other: the right-hand side operator axes
//!
//! Returns:
//! bool, true if other has the same set of Axis names as self
//!-----------------------------------------------------------------------------------
bool is_equal_set(const Axes& other) const;
//!-----------------------------------------------------------------------------------
//! is_not_equal_set
//! Returns true if other does not have the same set of Axis names as self
//!
//! Arguments:
//! other: the right-hand side operator axes
//!
//! Returns:
//! bool, true if other does not have the same set of Axis names as self
//!-----------------------------------------------------------------------------------
bool is_not_equal_set(const Axes& other) const;
//!-----------------------------------------------------------------------------------
//! T
//!-----------------------------------------------------------------------------------
// def T(self):
// return Axes(axis.T for axis in self)
//!-----------------------------------------------------------------------------------
//! index
//! Returns the index of an axis
//!
//! Arguments:
//! axis: The axis to search for.
//!
//! Returns:
//! The index.
//!-----------------------------------------------------------------------------------
size_t index(const Axis&) const;
// @with_args_as_axes
//!-----------------------------------------------------------------------------------
//! assert_valid_broadcast
//! Checks whether axes can be broadcasted to new_axes. We require
//! that the components of axes be laid out in the same order in new_axes.
//!
//! Arguments:
//! axes: The original axes.
//! new_axes: The broadcasted axes.
//!
//! Returns:
//! True if axes can be broadcasted to new_axes, False otherwise.
//!-----------------------------------------------------------------------------------
static void assert_valid_broadcast(const Axes& axes, const Axes& new_axes);
// @with_args_as_axes
//!-----------------------------------------------------------------------------------
//! is_valid_flatten_or_unflatten
//! Checks whether we can flatten OR unflatten from src_axes to dst_axes.
//!
//! The requirements are that the components of axes should all be
//! present in new_axes and that they should be laid out in the same
//! order. This check is symmetric.
//!-----------------------------------------------------------------------------------
static bool is_valid_flatten_or_unflatten(const Axes& src_axes, const Axes& dst_axes);
// @with_args_as_axes
//!-----------------------------------------------------------------------------------
//! assert_valid_flatten
//! Checks whether axes can safely be flattened to produce new_axes.
//! The requirements are that the components of axes should all be
//! present in new_axes and that they should be laid out in the same
//! order.
//!
//! Arguments:
//! unflattened_axes: The original axes.
//! flattened_axes: The flattened axes.
//!
//! Returns:
//! True if axes can be safely flattened to new_axes, False otherwise.
//!-----------------------------------------------------------------------------------
static void assert_valid_flatten(const Axes& unflattened_axes, const Axes& flattened_axes);
// @with_args_as_axes
//!-----------------------------------------------------------------------------------
//! assert_valid_unflatten
//! Checks whether axes can safely be unflattened to produce new_axes.
//! The requirements are that the components of axes should all be
//! present in new_axes and that they should be laid out in the same
//! order.
//!
//! Arguments:
//! flattened_axes: The original axes.
//! unflattened_axes: The unflattened axes.
//!
//! Returns:
//! True if axes can be safely unflattened to new_axes, False otherwise.
//!-----------------------------------------------------------------------------------
static void assert_valid_unflatten(const Axes& flattened_axes, const Axes& unflattened_axes);
//!-----------------------------------------------------------------------------------
//! size
//! TODO: delete this method, the size should come from the tensor
//!-----------------------------------------------------------------------------------
size_t size() const;
//!-----------------------------------------------------------------------------------
//! operator<<
//!-----------------------------------------------------------------------------------
friend std::ostream& operator<<(std::ostream&, const Axes&);
//!-----------------------------------------------------------------------------------
//! as_nested_list
//! Converts Axes to a list of axes with flattened axes expressed as nested lists
//!
//! Returns:
//! Nested list of Axis objects
//!-----------------------------------------------------------------------------------
static std::vector<Axis> as_nested_list(const Axes&);
//!-----------------------------------------------------------------------------------
//! as_flattened_list
//! Converts Axes to a list of axes with flattened axes expanded recursively.
//!
//! Returns:
//! List of Axis objects
//!-----------------------------------------------------------------------------------
static std::vector<Axis> as_flattened_list(const Axes&);
std::vector<Axis> convert(const Axes& ax);
std::vector<Axis> convert(const std::vector<Axes>& ax);
private:
void check_duplicates();
};
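//-------------------------------------------------------------------------------------------------
// Usage sketch for the set-style operators above (illustrative only): given feature axes C, H, W
// and a batch axis N,
//
//     Axes full = make_axes({C, H, W}) + make_axes({N}); // concatenation; throws on duplicates
//     Axes feat = full - make_axes({N});                 // ordered set difference: {C, H, W}
//     Axes both = feat & make_axes({H, W, N});           // ordered set intersection: {H, W}
//     Axes all  = feat | make_axes({N});                 // ordered set union
//-------------------------------------------------------------------------------------------------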
//================================================================================================
// DuplicateAxisNames
//================================================================================================
// class DuplicateAxisNames(ValueError):
// def __init__(self, message, duplicate_axis_names):
// super(DuplicateAxisNames, self).__init__(message)
// self.duplicate_axis_names = duplicate_axis_names
//================================================================================================
// IncompatibleAxesError
//================================================================================================
// class IncompatibleAxesError(ValueError):
// pass
//================================================================================================
// UnmatchedAxesError
//================================================================================================
// class UnmatchedAxesError(IncompatibleAxesError):
// pass
//================================================================================================
// AxesMap
// AxesMap provides a way to define an axis name mapping: {Axis.name: Axis.name} and
// then apply this mapping to an Axes and get new Axes out.
//
// Right now AxesMap is implemented as immutable because I didn't want to deal with
// enforcing _assert_valid_axes_map on every method that mutates the map, and I didn't
// need a mutable data structure anyway. Feel free to make it mutable and add in
// invariant enforcement.
//================================================================================================
class AxesMap : public std::map<std::string, std::string>
{
public:
AxesMap(const std::pair<std::string, std::string>&);
AxesMap(std::initializer_list<std::pair<std::string, std::string>>);
//--------------------------------------------------------------------------------------------
// Returns:
// Axes with lengths from axes and names which have been passed through axes_map
//--------------------------------------------------------------------------------------------
Axes map_axes(const Axes&) const;
//--------------------------------------------------------------------------------------------
// Given a map from {old_axes_name: new_axes_name} and an old_axis, map the
// old_axis into the new_axes.
//--------------------------------------------------------------------------------------------
Axis map_axis(const Axis& old_axis) const;
private:
std::map<std::string, std::set<std::string>> duplicate_axis_names();
void assert_valid_axes_map();
public:
// def invert(self):
// return {v: k for k, v in self.items()}
};
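//-------------------------------------------------------------------------------------------------
// Usage sketch (illustrative only): an AxesMap renames axes while keeping their lengths, e.g.
//
//     AxesMap rename{{"H", "H_out"}, {"W", "W_out"}};
//     Axes out = rename.map_axes(make_axes({H, W})); // same lengths, names mapped through rename
//-------------------------------------------------------------------------------------------------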
//-----------------------------------------------------------------------------------------------
// _reduce_nested
// Reduces a nested sequence by applying a function to each
// of its elements and returns an aggregation.
//
// Arguments:
// elem: The object to be reduced, either a sequence
// or a singleton.
// agg: A variable holding information collected
// as the sequence is collapsed.
// func: A function to augment the aggregate by processing
// a singleton. Should have the form func(agg, elem) -> agg
//
// Returns:
// agg: The final aggregate returned by the function.
//-----------------------------------------------------------------------------------------------
// def _reduce_nested(elem, agg, func):
// if isinstance(elem, collections.Iterable):
// for sub in elem:
// agg = _reduce_nested(sub, agg, func)
// return agg
// else:
// return func(agg, elem)
//================================================================================================
// FlattenedAxis
// A FlattenedAxis has length which is the product of the lengths of all
// Axis in the axes. The original Axes object is stored so that we can later
// unflatten this Axis back to its original component Axis.
//
// Notes: since we allow Axis to have duplicated names globally, NameableValue
// is not used here.
//================================================================================================
class FlattenedAxis : public Axis
{
public:
FlattenedAxis(const std::vector<Axis>& list, const std::string& new_name = "");
virtual ~FlattenedAxis() {}
//--------------------------------------------------------------------------------------------
// Returns:
// True if this is a FlattenedAxis.
//--------------------------------------------------------------------------------------------
bool is_flattened() const { return true; }
//--------------------------------------------------------------------------------------------
// Returns:
// Whether this axis contains no collapsed axes.
//--------------------------------------------------------------------------------------------
bool empty() const { return axes.size() == 0; }
//--------------------------------------------------------------------------------------------
// Returns:
// Whether this axis contains exactly one collapsed axis.
//--------------------------------------------------------------------------------------------
bool single() const { return axes.size() == 1; }
bool operator==(const Axis& other) const;
// def __hash__(self):
// return hash(self.axes)
friend std::ostream& operator<<(std::ostream&, const FlattenedAxis&);
virtual std::string to_string() const override;
// def __repr__(self):
// return 'FlattenedAxis(%s)' % ', '.join(repr(axis) for axis in self.axes)
std::vector<Axis> axes;
};
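//-------------------------------------------------------------------------------------------------
// Illustrative sketch: flattening collapses several axes into one whose length is the product of
// the component lengths (per the class comment above), e.g.
//
//     Axis C = make_axis(3, "C");
//     Axis H = make_axis(4, "H");
//     FlattenedAxis CH({C, H}); // CH.length() == 12, CH.is_flattened() == true
//-------------------------------------------------------------------------------------------------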
//-----------------------------------------------------------------------------------------------
// reduce_strides
// Reduces a nested tuple describing the strides of a tensor
// into a tuple giving the stride of each of its dimensions.
//
// Arguments:
// strides: The nested tuple.
//
// Returns:
// strides: The tuple of strides.
//-----------------------------------------------------------------------------------------------
// def reduce_strides(strides):
// return tuple(int(_reduce_nested(elem, float('inf'), min))
// for elem in strides)
//-----------------------------------------------------------------------------------------------
// _make_stride
// Generates a nested tuple that provides the striding information
// for an occurrence of axis. If the axis is a FlattenedAxis, the
// stride will be a tuple containing the strides of each collapsed
// axis. Otherwise, the stride will be an integer.
//
// Arguments:
// inner_size: The total size of all dimensions smaller than this
// axis, i.e. all axes to the right of this one when they are
// laid out in c-contiguous order.
// axis: The axis for which we are generating a stride.
// fsz: A nested tuple supplying the sizes of each dimension collapsed
// into the axis. The size may be larger than the length of the axis.
//
// Returns:
// inner_size: The total size of this axis and all smaller dimensions.
// stride: The stride given to the axis.
//-----------------------------------------------------------------------------------------------
// def _make_stride(inner_size, axis, fsz):
// if axis.is_flattened:
// return _make_strides(inner_size, axis.axes, fsz)
// else:
// stride = inner_size
// inner_size *= fsz
// return inner_size, stride
//-----------------------------------------------------------------------------------------------
// _make_strides
// Generates a tuple of strides for a set of axes. See _make_stride
// for a description of the stride given to each axis.
//
// Arguments:
// inner_size: The total size of all dimensions smaller than
// the axes.
// axes: The axes for which we are generating strides.
// full_sizes: The size of each axis.
//
// Returns:
// inner_size: The total size of these axes and all smaller dimensions.
// strides: The strides generated for the axes.
//-----------------------------------------------------------------------------------------------
// def _make_strides(inner_size, axes, full_sizes):
// full_strides = []
// for axis, fsz in reversed(list(zip(axes, full_sizes))):
// inner_size, stride = _make_stride(inner_size, axis, fsz)
// full_strides.append(stride)
// return inner_size, tuple(reversed(full_strides))
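//-------------------------------------------------------------------------------------------------
// A minimal C++ sketch of the same idea for non-flattened axes (illustrative only, assuming byte
// strides that start from the element size): strides are built right to left in c-contiguous
// order, each axis striding over the total size of everything to its right.
//
//     std::vector<size_t> make_row_major_strides(const std::vector<size_t>& full_sizes,
//                                                size_t item_size)
//     {
//         std::vector<size_t> strides(full_sizes.size());
//         size_t inner_size = item_size;
//         for (size_t i = full_sizes.size(); i-- > 0;)
//         {
//             strides[i] = inner_size;     // stride of axis i
//             inner_size *= full_sizes[i]; // total size of axis i and everything inside it
//         }
//         return strides;
//     }
//-------------------------------------------------------------------------------------------------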
//================================================================================================
// TensorDescription
// Description of a tensor that will be allocated in hardware.
//
// Names the tensor's dimensions with axes and holds pointers to the
// buffer allocated by the analysis and the backend tensor value
// (e.g. a cpu or gpu tensor).
//
// Arguments:
// axes: Axes of the tensor.
// base: If a view, the viewed tensor's description.
// dtype: The type of the tensor.
// full_strides: The strides of each axis.
// full_sizes: The allocated size of each axis (may be larger than the axis).
// offset: An offset into the viewed tensor.
// next_tensor_description: In a reshape, the tensor description of the reshaped tensor.
// is_persistent: The tensor should be persistent, i.e. survive from computation to
// computation.
// is_input: The device tensor can be written from the host.
// **kwargs: Additional args for related classes.
//================================================================================================
class TensorDescription : public NameableValue
{
public:
//!-----------------------------------------------------------------------------------
//! constructor
//!-----------------------------------------------------------------------------------
TensorDescription(op_ptr op = nullptr,
const Axes& _axes = Axes(),
tensor_description_ptr base = nullptr,
// layout,
ElementType et = element_type_float,
ngraph::tensor_stride full_strides = ngraph::tensor_stride(),
ngraph::tensor_size full_sizes = ngraph::tensor_size(),
size_t offset = 0,
TensorDescription* next_tensor_description = nullptr,
const std::string& name = "",
bool is_persistent = false,
bool is_input = false,
bool is_placeholder = false);
// std::string name() const;
std::vector<size_t> shape() const;
tensor_description_ptr base() const;
ElementType element_type() const;
size_t tensor_size() const;
//!-----------------------------------------------------------------------------------
//! operator<<
//!-----------------------------------------------------------------------------------
// def __repr__(self):
// return self.base.name
//!-----------------------------------------------------------------------------------
//! is_persistent
//! Returns: True if persists from computation to computation.
//!-----------------------------------------------------------------------------------
bool is_persistent() const;
//!-----------------------------------------------------------------------------------
//! is_input
//! Returns: True if writable from host.
//!-----------------------------------------------------------------------------------
bool is_input() const;
//!-----------------------------------------------------------------------------------
//! is_placeholder
//! Returns: True if a placeholder; a place to attach a tensor.
//!-----------------------------------------------------------------------------------
bool is_placeholder() const;
//!-----------------------------------------------------------------------------------
//! parameter_key
//! Returns: A tuple that can be used to tell if two views of a tensor are equivalent.
//!-----------------------------------------------------------------------------------
// @property
// def parameter_key(self):
// return (self.shape, self.dtype, self.offset, self.strides, self.layout)
size_t parameter_key() const;
//!-----------------------------------------------------------------------------------
//! axes_key
//!-----------------------------------------------------------------------------------
// @property
// def axes_key(self):
// return (self.axes, self.shape, self.dtype, self.offset, self.strides, self.layout)
axes_key_t axes_key() const;
//!-----------------------------------------------------------------------------------
//! flatten
//! Flattens a tensor description to give it the Axes in new_axes.
//! See Axes.assert_valid_flatten for a description of permitted values of new_axes.
//!
//! Arguments:
//! new_axes: The Axes of the flattened tensor description.
//!
//! Returns:
//! The reshaped tensor description.
//!-----------------------------------------------------------------------------------
// def flatten(self, new_axes):
// new_axes = Axes(new_axes)
// Axes.assert_valid_flatten(self.axes, new_axes)
// new_strides = []
// new_sizes = []
// idx = 0
// for new_axis in new_axes:
// if new_axis == self.axes[idx]:
// new_stride = self.full_strides[idx]
// new_size = self.full_sizes[idx]
// idx += 1
// else:
// l = len(new_axis.axes)
// new_stride = self.full_strides[idx:idx + l]
// new_size = self.full_sizes[idx:idx + l]
// idx += l
// new_strides.append(new_stride)
// new_sizes.append(new_size)
// return TensorDescription(
// new_axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=new_strides,
// full_sizes=new_sizes,
// offset=self.offset,
// next_tensor_description=self,
// name=self.name + 'rFlatten',
// )
//!-----------------------------------------------------------------------------------
//! unflatten
//! Unflattens a tensor description to give it the Axes in new_axes.
//! See Axes.assert_valid_unflatten for a description of the permitted values of
//! new_axes
//!
//! Arguments:
//! new_axes: The Axes of the unflattened TensorDescription.
//!
//! Returns:
//! The unflattened tensor description.
//!-----------------------------------------------------------------------------------
// def unflatten(self, new_axes):
// def find_axis_stride_and_length(axis):
// """
// Find the stride and length for an axis.
// Start at the current tensor description and then work back
// through reshapings of it looking for a mention of the axis
// that can be used to determine the storage stride and offset.
// Args:
// axis: The axis.
// Returns:
// stride, length of axis
// """
// td = self
// while td is not None:
// for idx, a in enumerate(td.axes):
// # Try to find a match for axis in this td
// full_strides = td.full_strides[idx]
// full_sizes = td.full_sizes[idx]
// if a == axis:
// return full_strides, full_sizes
// if a.is_flattened:
// # Can be embedded ina a flattened axis description
// if not isinstance(full_strides, tuple):
// # An axis cast can lose striding info, so need to
// # recreate it from the axis lengths. Being flattened
// # implies C-contiguous
// stride = full_strides
// full_strides = []
// full_sizes = []
// for s in reversed(a.axes):
// full_sizes.insert(0, s.length)
// full_strides.insert(0, stride)
// stride = stride * s.length
// # Now search for axis in the flattened axis
// for sub_idx, b in enumerate(a.axes):
// if b == axis:
// return full_strides[sub_idx], full_sizes[sub_idx]
// # Move on to the next tensor description in the reshaping chain
// td = td.next_tensor_description
// # Sometimes we just don't have enough information.
// raise ValueError()
// new_axes = Axes(new_axes)
// Axes.assert_valid_unflatten(self.axes, new_axes)
// new_strides = []
// new_sizes = []
// for new_axis in new_axes:
// stride, size = find_axis_stride_and_length(new_axis)
// new_strides.append(stride)
// new_sizes.append(size)
// return TensorDescription(
// new_axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=new_strides,
// full_sizes=new_sizes,
// offset=self.offset,
// next_tensor_description=self,
// name=self.name + 'rUnflatten',
// )
//!-----------------------------------------------------------------------------------
//! transpose
//! Reverses the axes of the tensor description.
//!
//! Returns:
//! A tensor description with the axes reversed.
//!-----------------------------------------------------------------------------------
// def transpose(self):
// new_axes = reversed(self.axes)
// full_sizes = reversed(self.full_sizes)
// full_strides = reversed(self.full_strides)
// return TensorDescription(
// Axes(new_axes),
// base=self.base,
// dtype=self.dtype,
// full_strides=tuple(full_strides),
// full_sizes=tuple(full_sizes),
// offset=self.offset,
// next_tensor_description=self,
// name=self.name + 'rTranspose',
// )
//!-----------------------------------------------------------------------------------
//! clone
//! Creates a copy of this tensor description
//!
//! Returns:
//! A copy of this tensor description
//!-----------------------------------------------------------------------------------
// def clone(self):
// return TensorDescription(
// self.axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=self.full_strides,
// full_sizes=self.full_sizes,
// offset=self.offset,
// next_tensor_description=self.next_tensor_description,
// name=self.name + 'cView',
// )
//!-----------------------------------------------------------------------------------
//! broadcast
//! Adds axes to a tensor description to give it a new shape.
//! See Axes.assert_valid_broadcast for a description of the permitted
//! transformations.
//!
//! Arguments:
//! new_axes: The axes of the broadcasted tensor description.
//!
//! Returns:
//! TensorDescription: The broadcasted tensor description.
//!-----------------------------------------------------------------------------------
TensorDescription broadcast(const Axes& new_axes);
//!-----------------------------------------------------------------------------------
//! reorder
//! Shuffles axes of a tensor to give it a new shape. The axes of
//! this tensor description and new_axes must have the same elements.
//!
//! Arguments:
//! new_axes: The axes of the reordered tensor.
//!
//! Returns:
//! TensorDescription: The reordered tensor description.
//!-----------------------------------------------------------------------------------
// def reorder(self, new_axes):
// if not self.axes.is_equal_set(new_axes):
// raise ValueError((
// "Reorder can't change which axes are available, only the "
// "order. {} and {} are different sets, not just order."
// ).format(self, new_axes))
// return self.reorder_and_broadcast(new_axes)
//!-----------------------------------------------------------------------------------
//! reorder_and_broadcast
//! Adds or shuffles axes to give a tensor description a new shape.
//! This function is used to implement broadcast and reorder.
//!
//! Arguments:
//! new_axes: The axes of the broadcasted or reordered tensor.
//!
//! Returns:
//! TensorDescription: A description of the tensor after the
//! transformation.
//!-----------------------------------------------------------------------------------
TensorDescription reorder_and_broadcast(const Axes& new_axes);
// def reorder_and_broadcast(self, new_axes):
// def zero_in_shape(tup):
// if isinstance(tup, collections.Iterable):
// return tuple(
// zero_in_shape(t) for t in tup
// )
// else:
// return 0
// new_axes = Axes(new_axes)
// new_strides = []
// new_sizes = []
// for axis in new_axes:
// if axis in self.axes:
// idx = self.axes.index(axis)
// new_strides.append(self.full_strides[idx])
// new_sizes.append(self.full_sizes[idx])
// elif axis.is_flattened:
// lengths = axis.axes.full_lengths
// new_strides.append(zero_in_shape(lengths))
// new_sizes.append(lengths)
// else:
// new_strides.append(0)
// new_sizes.append(axis.length)
// return TensorDescription(
// new_axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=new_strides,
// full_sizes=new_sizes,
// offset=self.offset,
// next_tensor_description=self,
// name=self.name + 'rReorderBroadcast',
// )
//!-----------------------------------------------------------------------------------
//! cast
//! Return a tensor description for a view of the tensor.
//!
//! Arguments:
//! new_axes: The axes for the view.
//!
//! Returns:
//! The tensor description.
//!-----------------------------------------------------------------------------------
// def cast(self, new_axes):
// full_strides = self.full_strides
// full_sizes = self.full_sizes
// if self.ndim == 0:
// full_strides = (0,) * len(new_axes)
// full_sizes = new_axes.full_lengths
// return TensorDescription(
// new_axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=full_strides,
// full_sizes=full_sizes,
// offset=self.offset,
// next_tensor_description=self,
// name=self.name + 'rCast',
// )
//!-----------------------------------------------------------------------------------
//! slice
//! Return a tensor description for a slice view of this tensor.
//!
//! Arguments:
//! slices: The slices to take from the tensor, each of which is
//! either an integer or a python slice. If the input has too few
//! axes for the tensor, we assume that the entire axis should be
//! taken for dimensions towards the end of the tensor.
//! new_axes: the axes to use as labels for the sliced tensor.
//!
//! Returns:
//! The tensor description for the slice.
//!-----------------------------------------------------------------------------------
// def slice(self, slices, new_axes):
// slices = list(slices)
// while len(slices) < self.ndim:
// slices.append(slice(None))
// offset = self.offset
// full_strides = []
// full_sizes = []
// new_index = 0
// # check new_axes for the correct length
// num_dimensions_out = len([s for s in slices if isinstance(s, slice)])
// if len(new_axes) != num_dimensions_out:
// raise ValueError((
// 'in a slice operation, the number of axes passed in to '
// 'new_axes ({num_new_axes}) must be the same as the number of '
// 'slice objects in slices ({num_slices}).'
// ).format(
// num_new_axes=len(new_axes),
// num_slices=num_dimensions_out,
// ))
// for s, axis, stride, size in zip(slices, self.axes, self.strides, self.sizes):
// if isinstance(s, slice):
// # only increment new_axis when the input slice is a slice and
// # not a integer
// new_axis = new_axes[new_index]
// new_index += 1
// # ensure slice is of the kind we support
// _validate_slice(s)
// # ensure new_axis has the correct length
// new_axis.length = _sliced_length(s, axis.length)
// start, stop, step = s.indices(axis.length)
// full_strides.append(stride * step)
// full_sizes.append(size)
// idx = start
// else:
// # this is a simple integer slice, ex: y = x[1]
// idx = s
// # TODO: write a test that fails if abs() is removed
// offset += idx * abs(stride)
// return TensorDescription(
// new_axes,
// base=self.base,
// dtype=self.dtype,
// full_strides=tuple(full_strides),
// full_sizes=tuple(full_sizes),
// offset=offset,
// next_tensor_description=self,
// name=self.name + "rSlice",
// )
//!-----------------------------------------------------------------------------------
//! shape
//! Returns: The shape of the tensor.
//!-----------------------------------------------------------------------------------
// @property
// def shape(self):
// return self.axes.lengths
//!-----------------------------------------------------------------------------------
//! strides
//! The strides of the tensor.
//!-----------------------------------------------------------------------------------
// @property
// def strides(self):
// return reduce_strides(self.full_strides)
//!-----------------------------------------------------------------------------------
//! sizes
//! The allocated sizes for each axis.
//!-----------------------------------------------------------------------------------
// @property
// def sizes(self):
// return tuple(_reduce_nested(_, 1, operator.mul)
// for _ in self.full_sizes)
//!-----------------------------------------------------------------------------------
//! tensor_size
//!-----------------------------------------------------------------------------------
// @property
// def tensor_size(self):
// result = self.dtype.itemsize
// for s in self.sizes:
// result = result * s
// return result
//!-----------------------------------------------------------------------------------
//! c_contiguous
//! Returns:
//! True if the tensor's strides are row-major contiguous.
//!-----------------------------------------------------------------------------------
// @property
// def c_contiguous(self):
// s = self.dtype.itemsize
// cstrides = []
// for _ in reversed(self.shape):
// cstrides.insert(0, s)
// s = s * _
// return tuple(cstrides) == self.strides
//!-----------------------------------------------------------------------------------
//! broadcast_contiguous
//! Returns:
//! True if tensor's strides are contiguous or broadcasted
//!-----------------------------------------------------------------------------------
// @property
// def broadcast_contiguous(self):
// if self.shape == ():
// return True
// broadcast_axes = np.where(np.equal(self.strides, 0))[0]
// aug_shape = list(self.shape)
// for bcast_axis in broadcast_axes:
// aug_shape[bcast_axis] = 1
// s = self.dtype.itemsize
// cstrides = []
// for _ in reversed(aug_shape):
// cstrides.insert(0, s)
// s = s * _
// for bcast_axis in broadcast_axes:
// cstrides[bcast_axis] = 0
// return tuple(cstrides) == self.strides
//!-----------------------------------------------------------------------------------
//! base
//! The viewed tensor description or None if not a view.
//!-----------------------------------------------------------------------------------
// @property
// def base(self):
// return self.__base or self
//!-----------------------------------------------------------------------------------
//! layout
//! The layout of the underlying storage.
//!-----------------------------------------------------------------------------------
// @property
// def layout(self):
// return self.__layout
//!-----------------------------------------------------------------------------------
//! layout
//! Sets the backend-specific memory layout to be used by the tensor.
//!
//! Arguments:
//! value: the layout to use
//!
//! Returns:
//!-----------------------------------------------------------------------------------
// @layout.setter
// def layout(self, value):
// self.__layout = value
//!-----------------------------------------------------------------------------------
//! register
//!-----------------------------------------------------------------------------------
// @property
// def register(self):
// return self.base.__register
//!-----------------------------------------------------------------------------------
//! register
//!-----------------------------------------------------------------------------------
// @register.setter
// def register(self, value):
// self.base.__register = value
//!-----------------------------------------------------------------------------------
//! is_base
//! This tensor provides its own storage.
//!-----------------------------------------------------------------------------------
// def is_base(self):
// return self.__base is None
op_ptr op;
Axes axes;
bool __is_persistent;
bool __is_input;
bool __is_placeholder;
tensor_description_ptr __base;
// __layout = layout
// __value = None
// __buffer = None
// __register = None
ElementType dtype;
size_t offset;
size_t ndim;
ngraph::tensor_size full_sizes;
ngraph::tensor_stride full_strides;
tensor_description_ptr next_tensor_description;
};
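//-------------------------------------------------------------------------------------------------
// Usage sketch (illustrative only): a TensorDescription pairs an op with Axes and an element
// type; shape() reports the axis lengths.
//
//     Axes ax = make_axes({make_axis(2, "N"), make_axis(3, "C")});
//     TensorDescription td(nullptr, ax);
//     // td.shape() == {2, 3}; element type defaults to element_type_float
//-------------------------------------------------------------------------------------------------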
} // end of namespace ngraph
namespace std
{
template <>
struct hash<ngraph::Axis>
{
size_t operator()(const ngraph::Axis& axis) const
{
std::hash<std::string> h1;
std::hash<size_t> h2;
return ngraph::hash_combine({h1(axis.name), h2(axis.length())});
}
};
}
namespace std
{
template <>
struct hash<ngraph::Axes>
{
size_t operator()(const ngraph::Axes& axes) const
{
std::hash<ngraph::Axis> h1;
std::vector<size_t> hashes;
for (auto axis : axes)
{
hashes.push_back(h1(axis));
}
return ngraph::hash_combine(hashes);
}
};
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cmath>
#include <exception>
#include <memory>
#include <sstream>
#include "exop.hpp"
#include "op_graph.hpp"
#include "util.hpp"
using namespace ngraph;
//================================================================================================
// InputDecl
//================================================================================================
InputDecl::InputDecl(const ExOp& _exop,
size_t _pos,
tensor_description_ptr _tensor_description,
OutputDecl* _value)
: exop{_exop}
, pos{_pos}
, tensor_description{_tensor_description}
, read_view{nullptr}
, m_value{_value}
{
}
TensorDecl& InputDecl::tensor_decl()
{
return read_view->tensor_decl;
}
OutputDecl* InputDecl::value()
{
// Returns: The argument supplying this value.
return m_value;
}
const OutputDecl* InputDecl::value() const
{
// Returns: The argument supplying this value.
return m_value;
}
void InputDecl::value(OutputDecl* value)
{
// Changes the value assigned to this argument, updating value users.
// Args:
// value: The new value for this argument.
if (m_value != nullptr)
{
remove_from(m_value->value_users, this);
remove_from(read_view->readers, this);
}
if (value != nullptr)
{
tensor_description = value->tensor_description;
}
m_value = value;
if (value != nullptr)
{
value->value_users.insert(this);
read_view = value->write_view()->get_tensor_view(tensor_description, this);
}
}
std::ostream& ngraph::operator<<(std::ostream& out, const InputDecl& obj)
{
out << "Arg(" << obj.exop.name() << obj.pos << ")";
return out;
}
//================================================================================================
// OutputDecl
//================================================================================================
OutputDecl::OutputDecl(const ExOp& _exop,
size_t _pos,
tensor_decl_ptr _tensor_decl,
tensor_description_ptr _tensor_description)
: exop{_exop}
, pos{_pos}
, tensor_description{_tensor_description}
, __tensor{_tensor_decl}
, __write_view{nullptr}
, value_users{}
{
}
tensor_decl_ptr OutputDecl::tensor_decl()
{
return __tensor;
}
void OutputDecl::tensor_decl(tensor_decl_ptr tensor_decl)
{
if (__tensor == tensor_decl)
{
return;
}
if (__tensor != nullptr)
{
tensor_decl->merge_flags(*__tensor);
}
__tensor = tensor_decl;
write_view(tensor_decl->get_tensor_view(tensor_description, this));
}
tensor_view_decl_ptr OutputDecl::write_view()
{
return __write_view;
}
void OutputDecl::write_view(tensor_view_decl_ptr view)
{
if (view == nullptr && value_users.size() > 0)
{
throw std::runtime_error("Cannot deallocate a view that is in use");
}
__write_view = view;
if (view != nullptr)
{
view->value = this;
for (InputDecl* arg : value_users)
{
arg->tensor_description = tensor_description;
arg->read_view = view->get_tensor_view(arg->tensor_description, arg);
}
}
}
std::ostream& ngraph::operator<<(std::ostream& out, const OutputDecl& obj)
{
out << "Val(" << obj.exop.name() << ":" << obj.pos << ")";
return out;
}
//================================================================================================
// ExOp
//================================================================================================
ExOp::ExOp(ComputationDecl& cgraph, op_ptr _op, bool create_value)
: ExecutionGraphElt{cgraph.execution_graph}
, computation_graph{cgraph}
, tensor_decl{nullptr}
, tensor_view{nullptr}
, ref_ops{}
, op{_op}
, liveness_live_list{}
, liveness_free_list{}
, liveness_new_list{}
, args{}
, write_args{}
, values{}
{
// moved from ExecuteExOp
if (op != nullptr)
{
computation_graph.ops[op] = this;
add_ref_op(op);
}
// endmoved from ExecuteExOp
// TODO shim Op needs to be fixed
// for (input_decl_ptr arg : op->args)
// {
// arg = arg.effective_tensor_op();
// exop_ptr exop = computation_graph.get_exop(arg);
// output_decl_ptr value = exop->values[0];
// add_arg(value);
// }
if (create_value && op != nullptr && op->is_tensor_op())
{
tensor_description_ptr tdesc = op->tensor_description();
tensor_decl_ptr tdecl = computation_graph.get_tensor_decl(op);
add_value(tdecl, tdesc);
}
}
std::ostream& ngraph::operator<<(std::ostream& out, const ExOp& obj)
{
out << obj.op->name();
std::vector<std::string> args;
for (const InputDecl& id : obj.args)
{
std::stringstream ss;
ss << id.value();
args.push_back(ss.str());
}
out << "\n\targs: " << join(args, ", ");
out << "\n\tvalues: " << join(obj.values, ", ");
out << "\n\tlive: " << join(obj.liveness_live_list, ", ");
out << "\n\tnew: " << join(obj.liveness_new_list, ", ");
out << "\n\tfree: " << join(obj.liveness_free_list, ", ");
return out;
}
exop_ptr ExOp::literal_scalar_exop(scalar_t scalar, ComputationDecl& computation_graph)
{
op_ptr op = std::make_shared<LiteralScalarOp>(scalar);
exop_ptr exop = std::make_shared<ExOp>(computation_graph, op);
exop->values[0].tensor_decl()->is_compile_only = true;
return exop;
}
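// A minimal usage sketch (assuming a ComputationDecl `cgraph` and an existing exop
// `some_exop` are in scope): create a compile-time scalar constant and feed it to
// another exop as an argument.
//
//     exop_ptr one = ExOp::literal_scalar_exop(1.0f, cgraph);
//     some_exop->add_arg(one->values[0]);
//
// Because the literal's tensor is marked is_compile_only, no device storage is
// allocated for it.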
InputDecl& ExOp::add_arg(OutputDecl& value, tensor_description_ptr tensor_description)
{
args.emplace_back(*this, args.size(), tensor_description, &value);
return args.back();
}
// InputDecl& ExOp::add_write_arg(OutputDecl& value, tensor_description_ptr tensor_description)
// {
// input_decl_ptr arg = std::make_shared<InputDecl>(this,
// args.size(),
// value,
// tensor_description);
// write_args.push_back(arg);
// return write_args.back();
// }
OutputDecl& ExOp::add_value(tensor_decl_ptr tdecl, tensor_description_ptr tensor_description)
{
if (tensor_description == nullptr)
{
tensor_description = tdecl->tensor_description_base;
}
values.emplace_back(*this, values.size(), tdecl, tensor_description);
return values.back();
}
// // void ExOp::take_value(output_decl_ptr value)
// // {
// // // TODO: this is not going to work ExOpNode is not an ExOp and it is not
// // // a shared pointer. exop is shared_ptr<ExOp> so we can't get there from
// // // here
// // value->exop = this;
// // value->pos = values.size();
// // }
op_ptr ExOp::get_op()
{
return op;
}
void ExOp::set_op(op_ptr _op)
{
if (_op == nullptr)
{
if (op == nullptr)
{
throw std::invalid_argument("Cannot set op to None.");
}
return;
}
if (_op->is_tensor_op())
{
op_ptr tensor_op = _op->tensor();
if (_op != tensor_op && tensor_op->is_state_op() == false)
{
add_ref_op(_op);
_op = tensor_op;
}
}
op = _op;
if (op != nullptr)
{
add_ref_op(op);
}
}
void ExOp::add_ref_op(op_ptr _op)
{
// Add another op that references this exop.
// Args:
// _op: The computation graph op referencing this exop.
ref_ops.push_back(_op);
computation_graph.ops[_op] = this;
}
size_t ExOp::memory_usage()
{
// Get the memory usage of this op, which is the sum of the sizes of all
// of the live tensors.
// Arguments:
// None
// Returns:
// Memory usage in bytes
size_t size = 0;
for (tensor_decl_ptr node : liveness_live_list)
{
size += node->size;
}
return size;
}
size_t ExOp::memory_footprint()
{
// Get the total memory footprint of this op. The footprint is the highest memory
// address used by the tensors in this op.
// Arguments:
// None
// Returns:
// Memory footprint in bytes
size_t max_mem = 0;
for (tensor_decl_ptr node : liveness_live_list)
{
size_t offset = node->size + node->buffer_pool_offset;
max_mem = std::max(offset, max_mem);
}
return max_mem;
}
size_t ExOp::memory_efficiency()
{
size_t mem = 100;
if (memory_footprint() > 0)
{
mem = static_cast<size_t>(round(float(memory_usage()) / float(memory_footprint()) * 100));
}
return mem;
}
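// Worked example with hypothetical sizes: if liveness_live_list holds two tensors of
// 256 and 128 bytes placed at buffer_pool_offsets 0 and 512, then
//     memory_usage()      == 256 + 128               == 384
//     memory_footprint()  == max(0 + 256, 512 + 128) == 640
//     memory_efficiency() == round(384.0 / 640.0 * 100) == 60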
bool ExOp::is_exop_end_of_list()
{
// Returns:
// True if this represents the guard past the exop list. See ExOpBlock.
return false;
}
std::string ExOp::name() const
{
return op->name();
}
//================================================================================================
// ExOpBlock
//================================================================================================
ExOpBlock::ExOpBlock(ComputationDecl& cgraph)
: ExecutionGraphElt{cgraph.execution_graph}
, computation_graph{cgraph}
, root_set{}
{
}
bool ExOpBlock::is_exop_end_of_list()
{
// Returns:
// True if this represents the guard past the exop list. See ExecuteOp.
return true;
}
void ExOpBlock::add_ops(std::initializer_list<computation_op_ptr> roots, exop_ptr after_exop)
{
// Add exops needed to compute ops in roots.
// Args:
// roots: A collection of ops whose values are needed.
// after_exop: Where in the list to add the ops. Defaults to the end.
// Get computation graph ops that have already been computed
std::vector<op_ptr> computed_ops;
// Everything in the list up to and including after_exop (or the entire list when
// after_exop is null) counts as already computed.
auto op_end = op_list.end();
if (after_exop)
{
    op_end = find(op_list.begin(), op_list.end(), after_exop);
    if (op_end != op_list.end())
    {
        ++op_end; // include after_exop itself
    }
}
for (auto op_iterator = op_list.begin(); op_iterator != op_end; ++op_iterator)
{
    exop_ptr exop = *op_iterator;
    computed_ops.push_back(exop->op);
    computed_ops.insert(computed_ops.end(), exop->ref_ops.begin(), exop->ref_ops.end());
    for (InputDecl& arg : exop->args)
    {
        computed_ops.push_back(arg.value()->exop.op);
        auto& ref_ops = arg.value()->exop.ref_ops;
        computed_ops.insert(computed_ops.end(), ref_ops.begin(), ref_ops.end());
    }
}
std::vector<op_ptr> available;
std::map<op_ptr, size_t> counts;
std::map<op_ptr, std::vector<op_ptr>> parents;
std::vector<op_ptr> ready;
available.insert(available.end(), roots.begin(), roots.end());
while (available.size() > 0)
{
op_ptr op = available.back();
available.pop_back();
if (contains_key(counts, op) || contains(computed_ops, op))
{
continue;
}
std::vector<op_ptr> children;
for (op_ptr child : op->all_deps())
{
if (!contains(computed_ops, child))
{
children.push_back(child);
}
}
// children = OrderedSet((child for child in op.all_deps if child not in computed_ops))
if (children.size() > 0)
{
counts[op] = children.size();
for (op_ptr child : children)
{
parents[child].push_back(op);
available.push_back(child);
}
}
else
{
ready.push_back(op);
}
}
while (ready.size() > 0)
{
op_ptr op = ready.back();
ready.pop_back();
after_exop = add_op(op, after_exop);
for (op_ptr p : parents[op])
{
size_t count = counts[p] - 1;
if (count == 0)
{
ready.push_back(p);
counts.erase(p);
}
else
{
counts[p] = count;
}
}
}
if (counts.size() > 0)
{
throw std::runtime_error("Graph not a DAG");
}
}
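// Example of the dependency bookkeeping above, for roots = {c} where c = add(a, b)
// and only `a` has already been computed: the first pass records counts[c] = 1
// (b is the only uncomputed dependency), parents[b] = {c}, and pushes b onto `ready`
// since it has no uncomputed dependencies of its own. The second pass emits b,
// decrements counts[c] to 0, then emits c; `counts` ends up empty. A non-empty
// `counts` map after the second pass means some op never became ready, i.e. the
// graph contains a cycle.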
exop_ptr ExOpBlock::add_op(op_ptr op, exop_ptr after_exop)
{
// Add an exop for op to be executed after after_exop.
// Args:
// op: The op.
// after_exop: The exop to precede op.
// Returns:
// The new last op. If the op is executable, it will be the added exop,
// otherwise the previous after_exop.
if (after_exop == nullptr && !op_list.empty())
{
after_exop = op_list.back();
}
if (op->is_sequencing_op())
{
return after_exop;
}
exop_ptr exop = std::make_shared<ExOp>(computation_graph, op);
return add_exop(exop, after_exop);
}
exop_ptr ExOpBlock::add_exop(exop_ptr exop, exop_ptr after_exop)
{
// Add exop to the list of exops, after after_exop.
// Args:
// exop:
// The exop to add.
// after_exop:
// If specified, the exop after which the new exop is inserted. Defaults to the
// last exop added.
// Returns:
// The exop.
if (after_exop == nullptr)
{
op_list.push_back(exop);
}
else
{
auto it = find(op_list.begin(), op_list.end(), after_exop);
if (it == op_list.end())
{
throw std::runtime_error("exop not found in op_list");
}
// list::insert inserts BEFORE the op, we want after so increment iterator
it++;
op_list.insert(it, exop);
}
return exop;
}
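// Usage sketch (hypothetical `exop_block` and `existing_exop`): append to the end of
// the block, or splice in after a specific exop.
//
//     exop_block->add_exop(exop);                 // after_exop == nullptr -> push_back
//     exop_block->add_exop(exop, existing_exop);  // inserted immediately after existing_exop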
void ExOpBlock::move_exop_to_after_exop(exop_ptr exop, exop_ptr after_exop)
{
auto it = find(op_list.begin(), op_list.end(), exop);
op_list.erase(it);
add_exop(exop, after_exop);
}
void ExOpBlock::remove_exop(exop_ptr exop)
{
auto it = find(op_list.begin(), op_list.end(), exop);
op_list.erase(it);
for (InputDecl& arg : exop->args)
{
arg.value()->value_users.erase(&arg);
}
}
// void ExOpBlock::replace_op(op_ptr old_op, op_ptr new_op)
// {
// // TODO Replacing an op can remove ops. For example, (x + 2) * 1 -> x + 2
// // replaces the * with +, so * and 1 drop out
// // 1 dropping out means one less constant tensor, if it's not used
// // anywhere else
// // * dropping out means a change to sequencing.
// new_op = as_op(new_op)
// old_exop = computation_graph.get_exop(old_op)
// new_exop = computation_graph.get_exop(new_op, None)
// if (new_exop == nullptr)
// {
// add_ops([new_op], after_exop=old_exop.prev_exop)
// new_exop = computation_graph->get_exop(new_op, None)
// }
// replace_users(old_exop, new_exop)
// remove_exop(old_exop)
// }
void ExOpBlock::replace_users(exop_ptr old_exop, exop_ptr new_exop)
{
// // Replace all users of old_exop with new_exop.
// // Args:
// // old_exop: The original exop.
// // new_exop: The replacement exop.
// for (int i=0; i<old_exop->values.size(); i++)
// {
// OutputDecl* old_value = &old_exop->values[i];
// OutputDecl* new_value = &new_exop->values[i];
// replace_value(old_value, new_value);
// }
// for (op_ptr op : old_exop->ref_ops)
// {
// new_exop->add_ref_op(op);
// }
// computation_graph.ops[old_exop->op] = new_exop;
}
// void ExOpBlock::replace_value(OutputDecl* old_value, OutputDecl* new_value)
// {
// for (InputDecl* value_user : old_value->value_users)
// {
// value_user->value(*new_value);
// }
// new_value->tensor_decl()->merge_flags(*old_value->tensor_decl());
// old_value->exop.values[old_value->pos] = *new_value;
// }
void ExOpBlock::replace_exop(exop_ptr old_exop, exop_ptr new_exop)
{
// add_exop(new_exop, old_exop->prev_exop);
// This SHOULD be the same as above
add_exop(new_exop, old_exop);
replace_users(old_exop, new_exop);
remove_exop(old_exop);
}
void ExOpBlock::merge_exop(exop_ptr old_exop, exop_ptr new_exop)
{
// new_exop, which should already exist, takes over for old_exop.
// Args:
// old_exop:
// new_exop:
replace_users(old_exop, new_exop);
remove_exop(old_exop);
}
size_t ExOpBlock::memory_footprint()
{
size_t max_mem = 0;
for (exop_ptr exop : *this)
{
max_mem = std::max(exop->memory_footprint(), max_mem);
}
return max_mem;
}
size_t ExOpBlock::worst_case_footprint()
{
size_t mem = 0;
for (OutputDecl* value : get_temp_vars())
{
mem += value->write_view()->tensor_decl.size;
}
return mem;
}
size_t ExOpBlock::memory_efficiency()
{
size_t footprint = memory_footprint();
size_t usage = 0;
for (exop_ptr exop : op_list)
{
usage = std::max(usage, exop->memory_usage());
}
size_t result = 100;
if (footprint > 0)
{
result = int(round((float(usage) / float(footprint)) * 100));
}
return result;
}
size_t ExOpBlock::persistent_size()
{
size_t mem = 0;
for (OutputDecl* value : get_persistent_vars())
{
mem += value->write_view()->tensor_decl.size;
}
return mem;
}
std::set<OutputDecl*> ExOpBlock::get_vars()
{
std::set<OutputDecl*> vars;
for (exop_ptr exop : op_list)
{
for (InputDecl& value : exop->args)
{
vars.insert(value.value());
}
for (OutputDecl& value : exop->values)
{
vars.insert(&value);
}
}
return vars;
}
std::set<OutputDecl*> ExOpBlock::get_temp_vars()
{
std::set<OutputDecl*> result;
for (OutputDecl* value : get_vars())
{
if (value->write_view()->tensor_decl.is_persistent == false)
{
result.insert(value);
}
}
return result;
}
std::set<OutputDecl*> ExOpBlock::get_persistent_vars()
{
std::set<OutputDecl*> result;
for (OutputDecl* value : get_vars())
{
if (value->write_view()->tensor_decl.is_persistent)
{
result.insert(value);
}
}
return result;
}
//================================================================================================
// TensorDecl
//================================================================================================
TensorDecl::TensorDecl(ExecutionGraph& eg,
ElementType _element_type,
size_t _size,
bool _is_persistent,
bool _is_input,
tensor_description_ptr _tensor_description_base,
bool _is_output,
bool _is_constant,
tensor_description_ptr tensor_description,
bool _is_compile_only)
: ExecutionGraphElt{eg}
, element_type{_element_type}
, size{_size}
, is_persistent{_is_persistent}
, is_input{_is_input}
, is_output{_is_output}
, buffer_pool_offset{0}
, tensor_view_decls{}
, tensor_description_base{_tensor_description_base}
, lifespan{0}
, is_constant{_is_constant}
, is_compile_only{_is_compile_only}
, initial_value{nullptr}
, source_tensor{this}
{
// TODO: fix this somehow
// if (tensor_description == nullptr)
// {
// if (op == nullptr)
// {
// tensor_description = tensor_description_base;
// }
// else
// {
// if (op->tensor()->is_state_op())
// {
// initial_value = op->tensor()->initial_value;
// }
// tensor_description = op->tensor_description();
// }
// }
// // TODO Needed for initialization. Use exop value instead.
// add_value(this, tensor_description)
}
tensor_view_decl_ptr
TensorDecl::get_tensor_view(tensor_description_ptr tdesc, InputDecl* reader, OutputDecl* writer)
{
tensor_view_decl_ptr tensor_view;
if (tdesc == nullptr)
{
tdesc = tensor_description_base;
}
tensor_view = tensor_view_decls[tdesc->axes_key()];
if (tensor_view == nullptr)
{
tensor_view = std::make_shared<TensorViewDecl>(*this, tdesc, execution_graph);
tensor_view_decls[tdesc->axes_key()] = tensor_view;
}
if (reader != nullptr)
{
tensor_view->readers.insert(reader);
}
if (writer != nullptr)
{
tensor_view->writers.insert(writer);
}
return tensor_view;
}
tensor_view_decl_ptr TensorDecl::get_tensor_view(tensor_description_ptr tdesc, InputDecl* reader)
{
return get_tensor_view(tdesc, reader, nullptr);
}
tensor_view_decl_ptr TensorDecl::get_tensor_view(tensor_description_ptr tdesc, OutputDecl* writer)
{
return get_tensor_view(tdesc, nullptr, writer);
}
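// Views are cached per axes_key: asking again with a tensor description that has the
// same axes returns the existing TensorViewDecl and only updates its readers/writers.
// Sketch, with hypothetical `tdecl`, `td`, InputDecl* `arg`, and OutputDecl* `val`:
//
//     tensor_view_decl_ptr v1 = tdecl->get_tensor_view(td, arg); // creates view, adds reader
//     tensor_view_decl_ptr v2 = tdecl->get_tensor_view(td, val); // same view, adds writer
//     assert(v1 == v2);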
void TensorDecl::merge_flags(const TensorDecl& tensor)
{
is_input |= tensor.is_input;
is_persistent |= tensor.is_persistent;
is_output |= tensor.is_output;
}
tensor_description_ptr TensorDecl::buffer_key()
{
// Returns: The key that makes this tensor unique in a buffer.
return tensor_description_base;
}
std::string TensorDecl::prefix()
{
std::stringstream ss;
ss << "a_";
if (!is_persistent)
{
ss << execution_graph.computation_decl->computation_op->name();
}
return ss.str();
}
std::string TensorDecl::variable_name()
{
std::stringstream ss;
ss << prefix() << "_" << tensor_name();
return ss.str();
}
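// Example with hypothetical names: a temporary tensor named "t12" used by a computation
// named "mycomp" yields variable_name() == "a_mycomp_t12"; a persistent tensor named
// "weights" is computation-independent, so its prefix is just "a_" and the variable
// name becomes "a__weights".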
std::string TensorDecl::tensor_name()
{
// Returns: Name used for the tensor.
return tensor_description_base->name();
}
std::string TensorDecl::buffer_name()
{
// Returns: Name used for the buffer.
return tensor_description_base->name();
}
// std::string TensorDecl::name()
// {
// return op->name();
// }
std::ostream& ngraph::operator<<(std::ostream& out, const TensorDecl& obj)
{
out << obj.tensor_description_base->name();
return out;
}
//================================================================================================
// TensorViewDecl
//================================================================================================
TensorViewDecl::TensorViewDecl(TensorDecl& _tensor_decl,
tensor_description_ptr _tensor_description,
ExecutionGraph& eg)
: ExecutionGraphElt{eg}
, tensor_decl{_tensor_decl}
, tensor_description{_tensor_description}
, readers{}
, writers{}
, value{nullptr}
{
// self.value = None
}
std::string TensorViewDecl::name() const
{
std::stringstream ss;
ss << tensor_decl.variable_name() << "_v_" << tensor_description->name();
ss << "_" << join(tensor_description->shape(), "x");
return ss.str();
// shape_str = "x".join((str(_) for _ in tensor_description.shape))
// return "{}_v_{}_{}".format(self.tensor_decl.variable_name,
// self.tensor_description.name,
// shape_str)
}
// op_ptr TensorViewDecl::op()
// {
// return tensor_decl->op;
// }
tensor_view_decl_ptr TensorViewDecl::get_tensor_view(tensor_description_ptr _tensor_description,
InputDecl* _reader,
OutputDecl* _writer)
{
return tensor_decl.get_tensor_view(_tensor_description, _reader, _writer);
}
tensor_view_decl_ptr TensorViewDecl::get_tensor_view(tensor_description_ptr _tensor_description,
InputDecl* _reader)
{
return tensor_decl.get_tensor_view(_tensor_description, _reader, nullptr);
}
tensor_view_decl_ptr TensorViewDecl::get_tensor_view(tensor_description_ptr _tensor_description,
OutputDecl* _writer)
{
return tensor_decl.get_tensor_view(_tensor_description, nullptr, _writer);
}
//================================================================================================
// ComputationDecl
//================================================================================================
ComputationDecl::ComputationDecl(ExecutionGraph& eg, computation_op_ptr op)
: ExecutionGraphElt{eg}
, computation_op{op}
{
exop_block = std::make_shared<ExOpBlock>(*this);
exop_block->add_ops({computation_op});
// returns = std::make_shared<ReturnExOp>(*this);
auto return_op = std::make_shared<ReturnOp>();
returns = std::make_shared<ExOp>(*this, return_op);
// Get the exops we need values for, so that if they are computed at compile-time we still
// have a view to their value.
for (op_ptr co : computation_op->values())
{
if (co->is_tensor_op())
{
exop_block->root_set.insert(get_exop(co));
}
}
for (ExOp* e : exop_block->root_set)
{
for (OutputDecl& value : e->values)
{
InputDecl& arg = returns->add_arg(value);
op_returns[e->op.get()] = &arg;
op_returns[e->op->tensor().get()] = &arg;
value.write_view()->tensor_decl.is_output = true;
}
}
for (op_ptr co : computation_op->values())
{
if (co->tensor()->is_tensor_op())
{
values.insert(get_exop(co));
}
}
}
tensor_decl_ptr ComputationDecl::get_tensor_decl(op_ptr _op)
{
return execution_graph.get_tensor_decl(_op);
}
ExOp* ComputationDecl::get_exop(op_ptr _op)
{
op_ptr original_op = _op;
_op = _op->effective_tensor_op();
if (_op->is_state_op())
{
throw std::runtime_error("Use get_tensor for AssignableTensorOp");
}
ExOp* exop = ops[_op];
if (exop != nullptr)
{
return exop;
}
// if (default_value != _default_default)
// {
// return default_value;
// }
std::stringstream ss;
ss << "Unhandled op: " << original_op;
throw std::runtime_error(ss.str());
}
//================================================================================================
// ExecutionState
//================================================================================================
ExecutionState::ExecutionState(transformer_ptr transformer)
: __transformer{transformer}
, __tensors_decls{}
{
}
transformer_ptr ExecutionState::transformer()
{
return __transformer;
}
execution_graph_ptr ExecutionState::make_execution_graph(computation_op_ptr computation_op)
{
return std::make_shared<ExecutionGraph>(*this, computation_op);
}
tensor_decl_ptr ExecutionState::get_op_tensor(op_ptr op)
{
tensor_description_ptr tensor_description = op->tensor_description();
tensor_description_ptr tensor_description_base = tensor_description->base();
return __tensors_decls[tensor_description_base];
}
tensor_decl_ptr ExecutionState::ensure_tensor_decl(ExecutionGraph& execution_graph,
tensor_description_ptr tensor_description,
op_ptr op)
{
tensor_description_ptr tensor_description_base = tensor_description->base();
tensor_decl_ptr tensor_decl = __tensors_decls[tensor_description_base];
if (tensor_decl == nullptr)
{
bool is_output = false;
bool is_constant = false;
bool is_compile_only = false;
tensor_decl = std::make_shared<TensorDecl>(execution_graph,
tensor_description_base->element_type(),
tensor_description_base->tensor_size(),
tensor_description_base->is_persistent(),
tensor_description_base->is_input(),
tensor_description_base,
is_output,
is_constant,
nullptr,
is_compile_only);
__tensors_decls[tensor_description_base] = tensor_decl;
}
return tensor_decl;
}
//================================================================================================
// ExecutionGraph
//================================================================================================
ExecutionGraph::ExecutionGraph(ExecutionState& es, computation_op_ptr computation_op)
: execution_state{es}
, tensor_decls{}
, computation_decl{std::make_shared<ComputationDecl>(*this, computation_op)}
{
}
tensor_decl_ptr ExecutionGraph::get_tensor_decl(op_ptr op,
tensor_description_ptr tensor_description)
{
if (tensor_description == nullptr)
{
tensor_description = op->tensor_description();
}
tensor_description_ptr tensor_description_base = tensor_description->base();
if (tensor_description_base->is_persistent())
{
return execution_state.ensure_tensor_decl(*this, tensor_description, op);
}
tensor_decl_ptr tensor_decl = tensor_decls[tensor_description_base];
if (tensor_decl == nullptr)
{
bool is_output = false;
bool is_constant = false;
bool is_compile_only = false;
tensor_decl = std::make_shared<TensorDecl>(*this,
tensor_description_base->element_type(),
tensor_description_base->tensor_size(),
tensor_description_base->is_persistent(),
tensor_description_base->is_input(),
tensor_description_base,
is_output,
is_constant,
nullptr,
is_compile_only);
tensor_decls[tensor_description_base] = tensor_decl;
}
return tensor_decl;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>
#include "axes.hpp"
#include "mock.hpp"
#include "op_graph.hpp"
namespace ngraph
{
// forward declaration. This will hopefully go away
class ExecutionGraph;
class TensorDescription;
class InputDecl;
class OutputDecl;
class TensorDecl;
class TensorViewDecl;
class ExOp;
class Op;
class ComputationDecl;
class ExOpBlock;
class ExecutionState;
using output_decl_ptr = std::shared_ptr<OutputDecl>;
using input_decl_ptr = std::shared_ptr<InputDecl>;
using tensor_decl_ptr = std::shared_ptr<TensorDecl>;
using tensor_view_decl_ptr = std::shared_ptr<TensorViewDecl>;
using exop_ptr = std::shared_ptr<ExOp>;
using computation_decl_ptr = std::shared_ptr<ComputationDecl>;
using execution_graph_ptr = std::shared_ptr<ExecutionGraph>;
using exop_block_ptr = std::shared_ptr<ExOpBlock>;
using tensor_ptr = std::shared_ptr<TensorInterface>;
using transformer_ptr = std::shared_ptr<Transformer>;
using execution_state_ptr = std::shared_ptr<ExecutionState>;
//================================================================================================
// OutputDecl
// One value computed by an exop
//
// Arguments:
// exop: The exop.
// pos: The position of the value, defaults to 0.
// tensor_description: Tensor description of the value.
// write_view: The tensor view where the value is written.
//
// Attributes:
// exop: The exop.
// pos: The position of the value.
// tensor_description: Tensor description of the value.
// write_view: The tensor view where the value is written.
// value_users: Arguments using this value.
//================================================================================================
class OutputDecl
{
public:
OutputDecl(const ExOp& _exop, size_t _pos, tensor_decl_ptr, tensor_description_ptr);
tensor_decl_ptr tensor_decl();
void tensor_decl(tensor_decl_ptr tensor_decl);
tensor_view_decl_ptr write_view();
void write_view(tensor_view_decl_ptr view);
friend std::ostream& operator<<(std::ostream& out, const OutputDecl& obj);
// def __repr__()
// {
// return "Val({exop}:{pos})".format(exop=self.exop.name, pos=self.pos)
// }
bool is_tensor_op() const;
const ExOp& exop;
size_t pos;
tensor_description_ptr tensor_description;
tensor_decl_ptr __tensor;
tensor_view_decl_ptr __write_view;
std::set<InputDecl*> value_users;
};
//================================================================================================
// InputDecl
// An argument for an exop.
//
// Arguments:
// exop: The exop.
// pos: The position of the value, defaults to 0.
// tensor_description: Tensor description of the value.
// read_view: The tensor view where the value is read from.
//
// Attributes:
// exop: The exop.
// pos: The position of the value.
// tensor_description: Tensor description of the value.
// read_view: The tensor view where the value is read from.
// value: The output (OutputDecl) that supplies this argument's value.
//================================================================================================
class InputDecl
{
public:
InputDecl(const ExOp& _exop,
size_t _pos,
tensor_description_ptr _tensor_description,
OutputDecl* _value);
TensorDecl& tensor_decl();
OutputDecl* value();
const OutputDecl* value() const;
void value(OutputDecl* value);
friend std::ostream& operator<<(std::ostream& out, const InputDecl& obj);
const ExOp& exop;
size_t pos;
tensor_description_ptr tensor_description;
tensor_view_decl_ptr read_view;
OutputDecl* m_value;
};
//================================================================================================
// ExecutionGraphElt
// An element of an execution graph.
//
// Arguments:
// execution_graph: The execution graph that indexes this exop.
//
// Attributes:
// execution_graph: The execution graph that indexes this exop.
//================================================================================================
class ExecutionGraphElt
{
public:
ExecutionGraphElt(ExecutionGraph& eg)
: execution_graph{eg}
{
}
ExecutionGraph& execution_graph;
};
//================================================================================================
// ExOp
//================================================================================================
class ExOp : public ExecutionGraphElt
{
public:
// An exop that indicates an op to be executed.
// The op might be different from what was originally found in the computation graph.
// The args are exops that reflect the current version of the graph, and may differ
// from the exops of the op's args.
// The views_in are the current tensor views for the args.
// The views_out are the current tensor views for any results.
// Arguments:
// op: The op to execute.
// Parameters:
// op: The computation graph op.
// views_in: Tensor views of the args.
// views_out: Tensor views of the result.
// Attributes:
// op: The computation graph op to execute.
// args: exops for the arguments.
// views_in: Views for the arguments.
// views_out: Views for the results.
// tensor: Tensor of the primary output.
// tensor_view: View of the primary output.
// ref_ops: All computation graph ops covered by this op
// op_map: A map from ops to ref ops, sha
ExOp(ComputationDecl& cgraph, op_ptr _op, bool create_value = true);
friend std::ostream& operator<<(std::ostream& out, const ExOp& obj);
// factory methods to make exops
static exop_ptr literal_scalar_exop(scalar_t scalar, ComputationDecl& computation_graph);
// A node in the graph, with inputs and outputs.
InputDecl& add_arg(OutputDecl& value, tensor_description_ptr tensor_description = nullptr);
InputDecl& add_write_arg(OutputDecl& value,
tensor_description_ptr tensor_description = nullptr);
OutputDecl& add_value(tensor_decl_ptr tensor_decl,
tensor_description_ptr tensor_description = nullptr);
op_ptr get_op();
void set_op(op_ptr _op);
void add_ref_op(op_ptr _op);
size_t memory_usage();
size_t memory_footprint();
size_t memory_efficiency();
bool is_exop_end_of_list();
std::string name() const;
ComputationDecl& computation_graph;
tensor_decl_ptr tensor_decl;
tensor_view_decl_ptr tensor_view;
std::vector<op_ptr> ref_ops;
op_ptr op;
std::vector<tensor_decl_ptr> liveness_live_list;
std::vector<tensor_decl_ptr> liveness_free_list;
std::vector<tensor_decl_ptr> liveness_new_list;
std::vector<InputDecl> args;
std::vector<InputDecl*>
write_args; // TODO: Kludge until we have values with writers/readers
std::vector<OutputDecl> values;
};
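// Construction sketch (hypothetical ComputationDecl `cgraph` and ops `add_op`, `a_op`,
// `b_op`, whose exops are already registered in `cgraph`):
//
//     auto add_exop = std::make_shared<ExOp>(cgraph, add_op); // creates values[0]
//     add_exop->add_arg(cgraph.get_exop(a_op)->values[0]);
//     add_exop->add_arg(cgraph.get_exop(b_op)->values[0]);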
//================================================================================================
// TensorDecl
//================================================================================================
class TensorDecl : public ExecutionGraphElt
{
public:
// Allocate for a tensor.
// Arguments:
// op: The AllocateTensorOp
// element_type: The type of the elements.
// size: The number of elements.
// is_persistent: True if the tensor is persistent.
// is_input: True if the tensor can be used as an argument.
// tensor_description_base: The base tensor description for the tensor.
// source_tensor: For a clone, the tensor that started the chain of clones
// this tensor is cloned from.
// Parameters:
// op: The AllocateTensorOp
// element_type: The type of the elements.
// size: The number of elements.
// is_persistent: True if the tensor is persistent.
// is_input: True if the tensor can be used as an argument.
// is_output: True if the tensor needs to be available for output. Defaults to is_persistent.
// tensor_descriptions: The set of tensor descriptions for the tensor.
// tensor_description_base: The tensor description base for this tensor.
// is_compile_only: If True, this tensor is only needed during compilation, and should not be
// allocated.
TensorDecl(ExecutionGraph&,
ElementType,
size_t,
bool _is_persistent,
bool _is_input,
tensor_description_ptr,
bool _is_output,
bool _is_constant,
tensor_description_ptr tensor_description,
bool _is_compile_only);
tensor_view_decl_ptr get_tensor_view(tensor_description_ptr tensor_description = nullptr,
InputDecl* reader = nullptr,
OutputDecl* writer = nullptr);
tensor_view_decl_ptr get_tensor_view(tensor_description_ptr tensor_description = nullptr,
InputDecl* reader = nullptr);
tensor_view_decl_ptr get_tensor_view(tensor_description_ptr tensor_description = nullptr,
OutputDecl* writer = nullptr);
void merge_flags(const TensorDecl& tensor);
tensor_description_ptr buffer_key();
std::string prefix();
std::string variable_name();
std::string tensor_name();
std::string buffer_name();
// std::string name();
friend std::ostream& operator<<(std::ostream& out, const TensorDecl& obj);
// op_ptr op;
ElementType element_type;
size_t size;
bool is_persistent;
bool is_input;
bool is_output;
size_t buffer_pool_offset;
std::map<axes_key_t, tensor_view_decl_ptr> tensor_view_decls;
tensor_description_ptr tensor_description_base;
size_t lifespan;
bool is_constant;
bool is_compile_only;
tensor_ptr initial_value;
tensor_decl_ptr source_tensor;
};
//================================================================================================
// ExOpBlock
//================================================================================================
class ExOpBlock : public ExecutionGraphElt
{
public:
// Sequentially execute a list of exops.
// Attributes:
// computation_graph: The associated computation graph.
// prev_exop: The last exop.
// next_exop: The first exop.
// root_set: Set of exops whose values are needed.
ExOpBlock(ComputationDecl& cgraph);
bool is_exop_end_of_list();
void add_ops(std::initializer_list<computation_op_ptr> roots,
exop_ptr after_exop = nullptr);
exop_ptr add_op(op_ptr op, exop_ptr after_exop);
exop_ptr add_exop(exop_ptr exop, exop_ptr after_exop = nullptr);
void move_exop_to_after_exop(exop_ptr exop, exop_ptr after_exop);
void remove_exop(exop_ptr exop);
void replace_op(op_ptr old_op, op_ptr new_op);
void replace_users(exop_ptr old_exop, exop_ptr new_exop);
void replace_value(OutputDecl* old_value, OutputDecl* new_value);
void replace_exop(exop_ptr old_exop, exop_ptr new_exop);
void merge_exop(exop_ptr old_exop, exop_ptr new_exop);
size_t memory_footprint();
size_t worst_case_footprint();
size_t memory_efficiency();
size_t persistent_size();
std::set<OutputDecl*> get_vars();
std::set<OutputDecl*> get_temp_vars();
std::set<OutputDecl*> get_persistent_vars();
ComputationDecl& computation_graph;
std::set<ExOp*> root_set;
// replacement for next_exop, prev_exop
std::list<exop_ptr>::iterator begin() { return op_list.begin(); }
std::list<exop_ptr>::iterator end() { return op_list.end(); }
std::list<exop_ptr> op_list;
};
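// Iteration sketch: ExOpBlock exposes begin()/end() over op_list, so a scheduling or
// analysis pass can range-for directly over the block (hypothetical `exop_block`):
//
//     size_t total = 0;
//     for (exop_ptr exop : *exop_block)
//     {
//         total += exop->memory_usage();
//     }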
//================================================================================================
// TensorViewDecl
//================================================================================================
class TensorViewDecl : public ExecutionGraphElt
{
public:
// Declare a view of a tensor.
// Arguments:
// tensor: The tensor.
// tensor_description: The description of the view.
TensorViewDecl(TensorDecl&, tensor_description_ptr, ExecutionGraph&);
std::string name() const;
// op_ptr op();
tensor_view_decl_ptr get_tensor_view(tensor_description_ptr, InputDecl*, OutputDecl*);
tensor_view_decl_ptr get_tensor_view(tensor_description_ptr, InputDecl*);
tensor_view_decl_ptr get_tensor_view(tensor_description_ptr, OutputDecl*);
// def key()
// {
// """
// // Returns: A tuple unique to this view of the tensor.
// """
// return tensor_description->parameter_key
// }
TensorDecl& tensor_decl;
tensor_description_ptr tensor_description;
// initializers;
std::set<InputDecl*> readers;
std::set<OutputDecl*> writers;
OutputDecl* value;
};
// static exop_ptr _default_default;
//================================================================================================
// ComputationDecl
//================================================================================================
class ComputationDecl : public ExecutionGraphElt
{
public:
// One computation to be run.
// Every computation has its own execution graph. Persistent tensors are shared
// between computations, other tensors are not.
// Attributes:
// computation: The computation op.
// ops: A map from ops to the exop that handles the op in this computation.
// exop: The SSA block of exops for this computation.
// values: The ops whose values are returned from the computation.
// tensors: Map from base tensor descriptions to tensors.
ComputationDecl(ExecutionGraph& eg, computation_op_ptr op);
tensor_decl_ptr get_tensor_decl(op_ptr _op = nullptr);
ExOp* get_exop(op_ptr _op);
computation_op_ptr computation_op;
std::map<op_ptr, ExOp*> ops;
std::vector<tensor_decl_ptr> tensors;
std::map<Op*, InputDecl*> op_returns; // op_returns_anchor?
exop_block_ptr exop_block;
exop_ptr returns;
std::set<ExOp*> values;
};
//================================================================================================
// ExecutionState
//================================================================================================
class ExecutionState
{
public:
// Proxy for the state of a device.
// Arguments:
// transformer: The associated transformer.
ExecutionState(transformer_ptr transformer = nullptr);
transformer_ptr transformer();
execution_graph_ptr make_execution_graph(computation_op_ptr);
tensor_decl_ptr get_op_tensor(op_ptr op);
tensor_decl_ptr ensure_tensor_decl(ExecutionGraph&, tensor_description_ptr, op_ptr);
transformer_ptr __transformer;
// persistent tensors
std::map<tensor_description_ptr, tensor_decl_ptr> __tensors_decls;
};
//================================================================================================
// ExecutionGraph
//================================================================================================
class ExecutionGraph
{
public:
// Information for compiling a computation_op.
// Arguments:
// execution_state: The execution state the graph will be applied to. The definitions in
// the execution state can be used in the execution graph.
// computation_op: A computation to be processed
ExecutionGraph(ExecutionState& execution_state, computation_op_ptr computation_op);
tensor_decl_ptr get_tensor_decl(op_ptr, tensor_description_ptr = nullptr);
ExecutionState& execution_state;
// temporary tensors
std::map<tensor_description_ptr, tensor_decl_ptr> tensor_decls;
computation_decl_ptr computation_decl;
};
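// End-to-end sketch (hypothetical `computation_op`): an ExecutionState owns the
// persistent tensor declarations; building an ExecutionGraph for a computation op
// produces the ComputationDecl whose exop block is then available for scheduling
// and memory-assignment passes.
//
//     ExecutionState state;
//     execution_graph_ptr graph = state.make_execution_graph(computation_op);
//     exop_block_ptr block = graph->computation_decl->exop_block;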
} // end namespace ngraph
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include "element_type.hpp"
namespace ngraph
{
class ExecutionState;
class Op;
// class TensorDescription;
class ComputationOp;
using computation_op_ptr = std::shared_ptr<ComputationOp>;
using op_ptr = std::shared_ptr<Op>;
using scalar_t = float;
//================================================================================================
// TensorInterface
//================================================================================================
class TensorInterface
{
public:
virtual ~TensorInterface() {}
virtual const ElementType& element_type() const = 0;
virtual std::string value_string() const = 0;
};
//================================================================================================
// Tensor
//================================================================================================
template <typename T>
class Tensor : public TensorInterface
{
public:
Tensor(const T& val)
: m_value{val}
, m_element_type{element_type_float}
{
}
virtual ~Tensor() {}
const ElementType& element_type() const override { return m_element_type; }
std::string value_string() const override
{
std::string rc = "WTF";
if (std::is_floating_point<T>::value)
{
std::stringstream ss;
ss << m_value;
rc = ss.str();
}
return rc;
}
private:
T m_value;
ElementType m_element_type;
};
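// Usage sketch: wrap a scalar in a Tensor and read it back as a string.
//
//     Tensor<float> t{2.5f};
//     t.element_type();  // element_type_float
//     t.value_string();  // "2.5"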
//================================================================================================
// Transformer
//================================================================================================
class Transformer
{
public:
virtual ~Transformer() {}
virtual ExecutionState& execution_state() = 0;
};
//================================================================================================
// TensorDescription
//================================================================================================
// class TensorDescription
// {
// public:
// virtual ~TensorDescription();
// virtual axes_key_t axes_key() const = 0;
// virtual std::string name() const = 0;
// virtual std::vector<size_t> shape() const = 0;
// virtual std::shared_ptr<TensorDescription> base() = 0;
// virtual ElementType element_type() const = 0;
// virtual size_t tensor_size() = 0;
// virtual bool is_persistent() = 0;
// virtual bool is_input() = 0;
// };
//================================================================================================
// Op
//================================================================================================
// class Op
// {
// // Any operation that can be in an AST.
// // Arguments:
// // args: Values used by this node.
// // const: The value of a constant Op, or None,
// // constant (bool): The Op is constant. Default False.
// // forward: If not None, the node to use instead of this node.
// // metadata: String key value dictionary for frontend metadata.
// // kwargs: Args defined in related classes.
// // Attributes:
// // const: The value of a constant.
// // constant (bool): The value is constant.
// // control_deps (OrderedSet): Ops in addition to args that must run before this op.
// // persistent (bool): The value will be retained from computation to computation and
// // not shared. Always True if reference is set.
// // metadata: Dictionary with of string keys and values used for attaching
// // arbitrary metadata to nodes.
// // trainable: The value is trainable.
// public:
// virtual ~Op() {}
// virtual std::string name() const = 0;
// virtual tensor_description_ptr tensor_description() = 0;
// virtual op_ptr tensor() = 0;
// virtual bool is_tensor_op() = 0;
// virtual bool is_state_op() const = 0;
// virtual bool is_sequencing_op() const = 0;
// virtual op_ptr effective_tensor_op() = 0;
// virtual const std::vector<op_ptr>& all_deps() const = 0;
// // ops
// // TODO support multiple types
// static op_ptr constant(float value)
// {
// op_ptr = make_shared<LiteralScalarOp>(value);
// }
// };
//================================================================================================
// TensorOp
//================================================================================================
// class TensorOp : public Op
// {
// public:
// std::string name() const override { return "TensorOp"; }
// tensor_description_ptr tensor_description() override { return nullptr; }
// op_ptr tensor() override { return nullptr; }
// bool is_tensor_op() override { return true; }
// bool is_state_op() const override { return false; }
// op_ptr effective_tensor_op() override { return nullptr; }
// const std::vector<op_ptr>& all_deps() const override { return m_all_deps; }
// private:
// std::vector<op_ptr> m_all_deps;
// };
} // end of namespace ngraph
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "mock_transformer.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "exop.hpp"
#include "mock.hpp"
namespace ngraph
{
//================================================================================================
// CpuTransformer
//================================================================================================
class CpuTransformer : public Transformer
{
public:
virtual ~CpuTransformer() {}
ExecutionState& execution_state() override { return m_execution_state; }
private:
ExecutionState m_execution_state;
};
} // end namespace ngraph
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ndarray.hpp"
ngraph::ndarray::ndarray(ElementType _dtype,
std::vector<size_t> _shape,
std::shared_ptr<char> _buffer,
size_t _offset,
const tensor_stride& _strides)
: dtype{_dtype}
, shape{_shape}
, buffer{_buffer}
, strides{_strides}
, offset{_offset}
{
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "element_type.hpp"
#include "strides.hpp"
namespace ngraph
{
class ndarray;
}
class ngraph::ndarray
{
public:
ndarray(ElementType dtype = element_type_float,
std::vector<size_t> shape = std::vector<size_t>(),
std::shared_ptr<char> data = nullptr,
size_t offset = 0,
const tensor_stride& strides = tensor_stride());
ElementType dtype;
std::vector<size_t> shape;
std::shared_ptr<char> buffer;
tensor_stride strides;
size_t offset;
};
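// Construction sketch (hypothetical buffer): describe a 2x3 float array over an
// externally owned buffer; offset and strides take their defaults when omitted.
//
//     std::shared_ptr<char> buf(new char[2 * 3 * sizeof(float)],
//                               std::default_delete<char[]>());
//     ngraph::ndarray arr{element_type_float, {2, 3}, buf};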
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "gtest/gtest.h"
#include "transformers/axes.hpp"
#include "transformers/ndarray.hpp"
using namespace std;
using namespace ngraph;
// axes for testing
static auto ax_A = make_axis(2, "A");
static auto ax_B = make_axis(3, "B");
static auto ax_C = make_axis(4, "C");
// axes for testing name matching behavior
static auto ax_A_ = make_axis(5, "A");
static auto ax_B_ = make_axis(6, "B");
static auto ax_C_ = make_axis(7, "C");
//=================================================================================================
// random
// return a random numpy array with dimension and dtype specified by
// tensor_description.
//
// Arguments:
// tensor_description: location of dimension and dtype specifications for
// returned array.
//=================================================================================================
ngraph::ndarray random(const TensorDescription& td)
{
ngraph::ndarray result{td.dtype, td.shape()};
// return np.random.random(
// tensor_description.shape
// ).astype(tensor_description.dtype)
return result;
}
//=================================================================================================
// tensorview
// Returns a numpy array whose buffer is nparr, using the
// tensor description in td
//
// Arguments:
// td TensorDescription: the description of the view of the nparr buffer
// nparr: the memory the np.array should use
//
// Returns:
// np.array view of nparr
//=================================================================================================
ngraph::ndarray tensorview(const TensorDescription& td, ngraph::ndarray& nparr)
{
ngraph::ndarray result{td.dtype, td.shape(), nparr.buffer};
// return np.ndarray(
// shape=td.shape,
// dtype=td.dtype,
// buffer=nparr,
// offset=td.offset,
// strides=td.strides
// )
return result;
}
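// Combined usage sketch: allocate a random array for a description, then make another
// view over the same underlying buffer (here with the same description).
//
//     TensorDescription td{};
//     ngraph::ndarray base = random(td);
//     ngraph::ndarray view = tensorview(td, base); // shares base.buffer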
void compute_eq(const std::vector<Axis>& _lhs, const std::vector<Axis>& _rhs, bool expected)
{
Axes lhs_axes = make_axes(_lhs);
Axes rhs_axes = make_axes(_rhs);
bool actual = lhs_axes == rhs_axes;
EXPECT_EQ(expected, actual);
}
void compute_ne(const std::vector<Axis>& _lhs, const std::vector<Axis>& _rhs, bool expected)
{
Axes lhs_axes = make_axes(_lhs);
Axes rhs_axes = make_axes(_rhs);
bool actual = lhs_axes != rhs_axes;
EXPECT_EQ(expected, actual);
}
void compute_add(const std::vector<Axis>& _lhs,
const std::vector<Axis>& _rhs,
const std::vector<Axis>& _expected,
bool expect_failure = false)
{
Axes lhs_axes = make_axes(_lhs);
Axes rhs_axes = make_axes(_rhs);
Axes expected = make_axes(_expected);
if (expect_failure)
{
EXPECT_THROW((lhs_axes + rhs_axes), std::invalid_argument);
}
else
{
Axes actual = lhs_axes + rhs_axes;
EXPECT_EQ(expected, actual);
}
}
void compute_subtract(const std::vector<Axis>& _lhs,
const std::vector<Axis>& _rhs,
const std::vector<Axis>& _expected,
bool expect_failure = false)
{
Axes lhs_axes = make_axes(_lhs);
Axes rhs_axes = make_axes(_rhs);
Axes expected = make_axes(_expected);
Axes actual = lhs_axes - rhs_axes;
EXPECT_EQ(expected, actual);
}
void compute_or(const std::vector<Axis>& _lhs,
const std::vector<Axis>& _rhs,
const std::vector<Axis>& _expected,
bool expect_failure = false)
{
Axes lhs_axes = make_axes(_lhs);
Axes rhs_axes = make_axes(_rhs);
Axes expected = make_axes(_expected);
Axes actual = lhs_axes | rhs_axes;
EXPECT_EQ(expected, actual);
}
void compute_and(const std::vector<Axis>& _lhs,
const std::vector<Axis>& _rhs,
const std::vector<Axis>& _expected,
bool expect_failure = false)
{
Axes lhs_axes = make_axes(_lhs);
Axes rhs_axes = make_axes(_rhs);
Axes expected = make_axes(_expected);
Axes actual = lhs_axes & rhs_axes;
EXPECT_EQ(expected, actual);
}
void compute_subset(const std::vector<Axis>& _lhs,
const std::vector<Axis>& _rhs,
bool expected = false)
{
Axes lhs_axes = make_axes(_lhs);
Axes rhs_axes = make_axes(_rhs);
bool actual = rhs_axes.is_sub_set(lhs_axes);
EXPECT_EQ(expected, actual);
}
void compute_superset(const std::vector<Axis>& _lhs,
const std::vector<Axis>& _rhs,
bool expected = false)
{
Axes lhs_axes = make_axes(_lhs);
Axes rhs_axes = make_axes(_rhs);
bool actual = rhs_axes.is_super_set(lhs_axes);
EXPECT_EQ(expected, actual);
}
void compute_eq_set(const std::vector<Axis>& _lhs,
const std::vector<Axis>& _rhs,
bool expected = false)
{
Axes lhs_axes = make_axes(_lhs);
Axes rhs_axes = make_axes(_rhs);
bool actual = lhs_axes.is_equal_set(rhs_axes);
EXPECT_EQ(expected, actual);
}
void compute_ne_set(const std::vector<Axis>& _lhs,
const std::vector<Axis>& _rhs,
bool expected = false)
{
Axes lhs_axes = make_axes(_lhs);
Axes rhs_axes = make_axes(_rhs);
bool actual = lhs_axes.is_not_equal_set(rhs_axes);
EXPECT_EQ(expected, actual);
}
TEST(axes, eq)
{
compute_eq({}, {}, true);
compute_eq({ax_A}, {}, false);
compute_eq({ax_A, ax_B}, {ax_B, ax_A}, false);
compute_eq({ax_A, ax_B}, {ax_B_, ax_A_}, false);
compute_eq({ax_A, ax_B}, {ax_A_, ax_B}, true);
compute_eq({ax_A, ax_B}, {ax_A_, ax_B_}, true);
}
TEST(axes, ne)
{
compute_ne({}, {}, false);
compute_ne({ax_A}, {}, true);
compute_ne({ax_A, ax_B}, {ax_B, ax_A}, true);
compute_ne({ax_A, ax_B}, {ax_B_, ax_A_}, true);
compute_ne({ax_A, ax_B}, {ax_A_, ax_B}, false);
compute_ne({ax_A, ax_B}, {ax_A_, ax_B_}, false);
}
TEST(axes, add)
{
compute_add({}, {}, {});
compute_add({ax_A}, {}, {ax_A});
compute_add({ax_A_}, {}, {ax_A});
compute_add({ax_A}, {ax_B}, {ax_A, ax_B});
compute_add({ax_A_}, {ax_B_}, {ax_A, ax_B});
// add (list operation, test exception)
compute_add({ax_A}, {ax_A}, {}, true);
compute_add({ax_A}, {ax_A_}, {}, true);
compute_add({ax_A}, {ax_A_, ax_B}, {}, true);
}
TEST(axes, subtract)
{
compute_subtract({}, {}, {});
compute_subtract({}, {ax_A}, {});
compute_subtract({ax_A}, {}, {ax_A});
compute_subtract({ax_A, ax_B}, {ax_B}, {ax_A});
compute_subtract({ax_A, ax_B}, {ax_B_}, {ax_A});
compute_subtract({ax_A, ax_B}, {ax_A}, {ax_B});
compute_subtract({ax_A, ax_B}, {ax_A_}, {ax_B});
compute_subtract({ax_A, ax_B}, {ax_B, ax_A}, {});
compute_subtract({ax_A, ax_B}, {ax_B_, ax_A_}, {});
}
TEST(axes, or)
{
compute_or({}, {}, {});
compute_or({}, {ax_A}, {ax_A});
compute_or({ax_A}, {}, {ax_A});
compute_or({ax_A}, {ax_B}, {ax_A, ax_B});
compute_or({ax_A}, {ax_A_}, {ax_A});
compute_or({ax_A}, {ax_A_}, {ax_A_});
}
TEST(axes, and)
{
compute_and({}, {}, {});
compute_and({}, {ax_A}, {});
compute_and({ax_A}, {}, {});
compute_and({ax_A}, {ax_B}, {});
compute_and({ax_A, ax_B}, {ax_B, ax_C}, {ax_B});
compute_and({ax_A, ax_B_}, {ax_B, ax_C}, {ax_B});
}
TEST(axes, sub_set)
{
compute_subset({}, {}, true);
compute_subset({ax_A}, {}, false);
compute_subset({}, {ax_A}, true);
compute_subset({ax_A_}, {ax_A}, true);
compute_subset({ax_A, ax_B}, {ax_B, ax_A}, true);
compute_subset({ax_A, ax_B}, {ax_B_, ax_A_}, true);
}
TEST(axes, super_set)
{
compute_superset({}, {}, true);
compute_superset({ax_A}, {}, true);
compute_superset({}, {ax_A}, false);
compute_superset({ax_A_}, {ax_A}, true);
compute_superset({ax_A, ax_B}, {ax_B, ax_A}, true);
compute_superset({ax_A, ax_B}, {ax_B_, ax_A_}, true);
}
TEST(axes, eq_set)
{
compute_eq_set({}, {}, true);
compute_eq_set({ax_A}, {}, false);
compute_eq_set({ax_A}, {ax_A}, true);
compute_eq_set({ax_A}, {ax_A_}, true);
compute_eq_set({ax_A, ax_B}, {ax_B_, ax_A_}, true);
}
TEST(axes, ne_set)
{
compute_ne_set({}, {}, false);
compute_ne_set({ax_A}, {}, true);
compute_ne_set({ax_A}, {ax_A}, false);
compute_ne_set({ax_A}, {ax_A_}, false);
compute_ne_set({ax_A, ax_B}, {ax_B_, ax_A_}, false);
}
TEST(axes, index)
{
Axis C = make_axis(5, "C");
Axis H = make_axis(3, "H");
Axis N = make_axis(7, "N");
Axes a{C, H, N};
EXPECT_EQ(5, a[0].length());
EXPECT_EQ(3, a[1].length());
EXPECT_EQ(7, a[2].length());
Axes b{{C, H}, N};
EXPECT_EQ(15, b[0].length());
EXPECT_EQ(7, b[1].length());
}
TEST(axes, DISABLED_as_nested_list)
{
Axis C = make_axis(5);
Axis H = make_axis(3);
Axis N = make_axis(7);
Axes a{C, H, N};
cout << "a " << a << endl;
Axes b{{C, H}, N};
cout << "b " << b << endl;
FAIL();
}
TEST(axes, DISABLED_flatten)
{
Axis C = make_axis(5);
Axis H = make_axis(3);
Axis N = make_axis(7);
Axes b{{C, H}, N};
auto c = b.flatten();
EXPECT_TRUE(c.is_flattened());
}
TEST(axes, DISABLED_as_flattened_list)
{
FAIL();
}
// This test just has to compile
TEST(axes, hash_axis)
{
std::hash<Axis> h1;
std::hash<Axes> h2;
(void)h1;
(void)h2;
std::unordered_map<Axis, int> m1; // needs operator==
std::map<Axis, int> m2; // needs operator<
m1[ax_A] = 1;
m2[ax_A] = 1;
}
TEST(axes, hash_axes)
{
Axes axes = make_axes({ax_A, ax_B});
std::unordered_map<Axes, int> m1; // needs operator==
std::map<Axes, int> m2; // needs operator<
m1[axes] = 1;
m2[axes] = 1;
}
TEST(axes, DISABLED_reaxe_0d_to_1d)
{
TensorDescription td{};
ngraph::ndarray x = random(td);
// create view of x
// auto btd = td.broadcast({ax_A});
// auto x_view = tensorview(btd, x);
// # set x
// x[()] = 3
// # setting e also sets x_view
// assert x_view.shape == (ax_A.length,)
// assert np.all(x_view == 3)
FAIL();
}
TEST(axes, DISABLED_reaxe_0d_to_2d)
{
// td = TensorDescription(axes=())
// x = random(td)
// x_view = tensorview(td.broadcast([ax_A, ax_B]), x)
// # set x
// x[()] = 3
// assert x_view.shape == (ax_A.length, ax_B.length)
// assert np.all(x_view == 3)
FAIL();
}
//-----------------------------------------------------------------------------------------------
// tons of tests relating to reaxeing tensors.
//
// variables names have a postfix integer which represents the dimensionality
// of the value. Views have x_y postfix which means they are y dimensional
// views of x dimensional buffers.
//
// I started refactoring into smaller pieces as seen in tests above, but
// stopped ...
//-----------------------------------------------------------------------------------------------
TEST(axes, DISABLED_simple_tensors)
{
// # A simple vector
// td1 = TensorDescription(axes=[ax_A])
// e1 = random(td1)
// td2 = TensorDescription(axes=[ax_A, ax_B])
// e2 = random(td2)
// # Reaxes
// e1_1 = tensorview(td1.broadcast([ax_A, ax_B]), e1)
// e1_2 = tensorview(td1.broadcast([ax_B, ax_A]), e1)
// e1_3 = tensorview(td1.broadcast([(ax_B, ax_C), ax_A]), e1)
// e2_1 = tensorview(td2.broadcast([ax_B, ax_A]), e2)
// e2_2 = tensorview(td2.broadcast([ax_A, ax_B]), e2)
// e2_3 = tensorview(td2.flatten((
// FlattenedAxis((ax_A, ax_B)),
// )), e2_2)
// assert e1_1.shape == (ax_A.length, ax_B.length)
// assert e1_2.shape == (ax_B.length, ax_A.length)
// for i in range(ax_A.length):
// e1_1[i] = i
// for i in range(ax_A.length):
// assert e1[i] == i
// for j in range(ax_B.length):
// assert e1_1[i, j] == i
// assert e1_2[j, i] == i
// for j in range(ax_B.length * ax_C.length):
// assert e1_3[j, i] == i
// def val2(i, j):
// return (i + 1) * (j + 2)
// for i in range(ax_A.length):
// for j in range(ax_B.length):
// e2[i, j] = val2(i, j)
// for i in range(ax_A.length):
// for j in range(ax_B.length):
// assert e2_1[j, i] == val2(i, j)
// assert e2_2[i, j] == val2(i, j)
// assert e2_3[i * ax_B.length + j] == val2(i, j)
FAIL();
}
TEST(axes, sliced_axis)
{
auto a = make_axis(10);
auto s = slice_axis(a, slice(0, 5));
EXPECT_EQ(5, s.length());
}
TEST(axes, sliced_axis_invalid)
{
auto a = make_axis(10);
auto s = slice_axis(a, slice(5, 0));
EXPECT_EQ(0, s.length());
}
TEST(axes, sliced_axis_none_end)
{
auto a = make_axis(10);
auto s = slice_axis(a, slice(0));
EXPECT_EQ(10, s.length());
}
TEST(axes, sliced_axis_negative)
{
auto a = make_axis(10);
auto s = slice_axis(a, slice(5, 0, -1));
EXPECT_EQ(5, s.length());
}
TEST(axes, sliced_axis_negative_invalid)
{
auto a = make_axis(10);
auto s = slice_axis(a, slice(0, 5, -1));
EXPECT_EQ(0, s.length());
}
TEST(axes, sliced_axis_flip)
{
auto a = make_axis(10);
auto s = slice_axis(a, slice(-1, -1, -1));
EXPECT_EQ(0, s.length());
}
TEST(axes, sliced_axis_invalid_step)
{
EXPECT_THROW(slice(0, 5, 2), std::invalid_argument);
}
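// Editorial sketch, not from the original file: the expected lengths in the
// slice tests above are consistent with Python-style slice semantics limited
// to steps of +1 and -1 (a step of 2 throws). A hypothetical helper that
// reproduces the arithmetic for an axis of length n:
#include <algorithm>
#include <cassert>
inline long sliced_length_sketch(long start, long stop, long step, long n)
{
    // Negative indices count from the end, as in Python.
    if (start < 0) start += n;
    if (stop < 0) stop += n;
    long length = (step > 0) ? stop - start : start - stop;
    return std::max(0L, length);
}
inline void sliced_length_checks()
{
    assert(sliced_length_sketch(0, 5, 1, 10) == 5);    // slice(0, 5)
    assert(sliced_length_sketch(5, 0, 1, 10) == 0);    // slice(5, 0)
    assert(sliced_length_sketch(0, 10, 1, 10) == 10);  // slice(0), no end -> n
    assert(sliced_length_sketch(5, 0, -1, 10) == 5);   // slice(5, 0, -1)
    assert(sliced_length_sketch(0, 5, -1, 10) == 0);   // slice(0, 5, -1)
    assert(sliced_length_sketch(-1, -1, -1, 10) == 0); // slice(-1, -1, -1)
}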
TEST(axes, sliced_batch_axis)
{
// slicing a batch axis should result in a batch axis
auto a = make_axis(10, "N");
ASSERT_TRUE(a.is_batch());
auto s = slice_axis(a, slice(0, 5));
EXPECT_TRUE(s.is_batch());
}
TEST(axes, sliced_recurrent_axis)
{
// slicing a recurrent axis should result in a recurrent axis
auto a = make_axis(10, "REC");
ASSERT_TRUE(a.is_recurrent());
auto s = slice_axis(a, slice(0, 5));
EXPECT_TRUE(s.is_recurrent());
}
TEST(axes, duplicate_axis_names)
{
try
{
AxesMap({{"aaa", "zzz"}, {"bbb", "zzz"}, {"ccc", "yyy"}});
FAIL();
}
catch (const std::invalid_argument& e)
{
EXPECT_TRUE(std::string(e.what()).find("aaa") != std::string::npos);
EXPECT_TRUE(std::string(e.what()).find("bbb") != std::string::npos);
EXPECT_TRUE(std::string(e.what()).find("zzz") != std::string::npos);
}
catch (...)
{
FAIL();
}
}
TEST(axes, invalid_axes_map_message)
{
try
{
AxesMap({{"aaa", "zzz"}, {"bbb", "zzz"}, {"ccc", "yyy"}});
FAIL();
}
catch (const std::invalid_argument& e)
{
EXPECT_TRUE(std::string(e.what()).find("aaa") != std::string::npos);
EXPECT_TRUE(std::string(e.what()).find("bbb") != std::string::npos);
EXPECT_TRUE(std::string(e.what()).find("zzz") != std::string::npos);
EXPECT_FALSE(std::string(e.what()).find("ccc") != std::string::npos);
EXPECT_FALSE(std::string(e.what()).find("yyy") != std::string::npos);
}
catch (...)
{
FAIL();
}
}
TEST(axes, axes_map)
{
// map from Axes([aaa, bbb]) to Axes([zzz, bbb]) via AxesMap {aaa: zzz}
auto a = make_axis(10, "aaa");
auto b = make_axis(10, "bbb");
auto z = make_axis(10, "zzz");
// axes_before = ng.make_axes([a, b])
auto axes_before = make_axes({a, b});
// axes_after = ng.make_axes([z, b])
auto axes_after = make_axes({z, b});
// axes_map = AxesMap({a.name: z.name})
AxesMap axes_map({a.name, z.name});
EXPECT_EQ(axes_after, axes_map.map_axes(axes_before));
// assert axes_after == axes_map.map_axes(axes_before)
}
TEST(axes, DISABLED_axes_map_immutable)
{
FAIL();
// axes_map = AxesMap({})
// with pytest.raises(TypeError):
// axes_map["x"] = "y"
}
TEST(axes, DISABLED_axes_map_init_from_axes)
{
FAIL();
// axes_map = AxesMap({ng.make_axis(1, name="aaa"): ng.make_axis(1, name="zzz")})
// assert axes_map["aaa"] == "zzz"
}
TEST(axes, duplicates)
{
auto a = make_axis(10, "aaa");
auto b = make_axis(10, "bbb");
auto z = make_axis(10, "zzz");
vector<Axis> a1{a, b, z};
vector<Axis> a2{a, b, b, z};
auto l1 = duplicates(a1);
auto l2 = duplicates(a2);
EXPECT_EQ(0, l1.size());
ASSERT_EQ(1, l2.size());
EXPECT_STREQ("bbb", l2[0].c_str());
}
TEST(tensor_description, broadcast)
{
// TensorDescription td1{};
// TensorDescription td2 = td1.broadcast({ax_A});
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "transformers/exop.hpp"
#include "transformers/mock.hpp"
#include "transformers/mock_transformer.hpp"
using namespace ngraph;
TEST(exop, create)
{
// CpuTransformer transformer;
// op_ptr c1 = constant(1.0);
// op_ptr a1 = add(c1, c1);
// std::vector<op_ptr> inputs;
// std::vector<op_ptr> outputs = {a1};
// auto computation_op = std::make_shared<ComputationOp>(inputs, outputs);
// ExecutionState& es = transformer.execution_state();
// execution_graph_ptr eg = es.make_execution_graph(computation_op); // one at a time
// computation_decl_ptr cd = eg->computation_decl;
// // transformer.run_passes(computation_decl_ptr);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "names.hpp"
using namespace ngraph;
TEST(names, name)
{
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "transformers/op_graph.hpp"
using namespace ngraph;
TEST(op_graph, constant)
{
float expected_value = 42;
op_ptr x = constant(expected_value);
ASSERT_NE(nullptr, x);
EXPECT_TRUE(x->is_constant());
EXPECT_FALSE(x->is_input());
EXPECT_TRUE(x->is_persistent());
EXPECT_FALSE(x->is_trainable());
EXPECT_FALSE(x->is_placeholder());
auto ato = std::dynamic_pointer_cast<AssignableTensorOp>(x);
ASSERT_NE(nullptr, ato);
// TODO: fix this
auto ti = ato->m_value;
ASSERT_NE(nullptr, ti);
std::string actual_value = ti->value_string();
std::stringstream ss;
ss << expected_value;
std::string expected_string = ss.str();
EXPECT_STREQ(actual_value.c_str(), expected_string.c_str());
}
// @pytest.fixture()
// def N():
Axis N()
{
// return ng.make_axis(length=1)
return make_axis(1);
}
// def test_deriv_missing_connection(N):
// """
// Taking the derivative of an expression with respect to a variable not
// used to compute the expression should raise an exception.
// """
TEST(op_graph, deriv_missing_connection)
{
// x = ng.variable([N])
// auto x = variable({N()});
// y = ng.variable([N])
// z = ng.variable([N])
// with pytest.raises(ValueError):
// ng.deriv(x + y, z)
}
// def test_one():
// # Test that the caching on constant one used in DerivOp works.
// op = ng.variable([])
// one_0 = op.one
// one_1 = op.one
// assert one_0 is one_1
// def test_pad_invalid_paddings_length(N):
// """
// pad should raise an exception if the paddings length is not the same as the
// input dimensionality.
// """
// x = ng.variable([N])
// with pytest.raises(ValueError):
// ng.pad(x, [1, 0])
// def test_pad_0(N):
// """
// pad with length 0 should be a nop
// """
// x = ng.variable([N])
// assert ng.pad(x, [0]).axes == x.axes
// def test_pad_mixed():
// """
// mix 0 padding with non-0 padding
// """
// input_axes = ng.make_axes([
// ng.make_axis(1),
// ng.make_axis(1)
// ])
// x = ng.variable(input_axes)
// pad = ng.pad(x, [0, 1])
// assert pad.axes[0] == x.axes[0]
// assert pad.axes[1] != x.axes[1]
// def test_slice_nop():
// """
// slicing an axis shouldn't change the name
// """
// input_axes = ng.make_axes([
// ng.make_axis(1),
// ng.make_axis(1)
// ])
// x = ng.variable(input_axes)
// s = ng.tensor_slice(x, [
// slice(None, None, None),
// slice(None, None, 1),
// ])
// assert s.axes[0] == x.axes[0]
// assert s.axes[1] == x.axes[1]
// def test_tensor_slice():
// """
// slicing a tensor should work like numpy
// """
// input_axes = ng.make_axes([
// ng.make_axis(10),
// ng.make_axis(20),
// ng.make_axis(5)
// ])
// x = ng.placeholder(axes=input_axes)
// assert x[:5].axes.full_lengths == (5, 20, 5)
// assert x[:, 2:7].axes.full_lengths == (10, 5, 5)
// assert x[:5, :, :-1].axes.full_lengths == (5, 20, 4)
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "gtest/gtest.h"
#include "strides.hpp"
using namespace std;
using namespace ngraph;
TEST(strides, scalar_tree_ctor)
{
{
ngraph::scalar_tree tree{2, 3, 4};
stringstream ss;
{
ss << tree;
EXPECT_STREQ("(2, 3, 4)", ss.str().c_str());
}
}
{
ngraph::scalar_tree tree{{2, 3}, 4};
stringstream ss;
{
ss << tree;
EXPECT_STREQ("((2, 3), 4)", ss.str().c_str());
}
}
{
ngraph::scalar_tree tree{{1, 2}, {3, 4}, 5, {6, 7}};
stringstream ss;
{
ss << tree;
EXPECT_STREQ("((1, 2), (3, 4), 5, (6, 7))", ss.str().c_str());
}
}
{
ngraph::scalar_tree tree{1, {2, {3, 4}, {5, 6}}, 7};
stringstream ss;
{
ss << tree;
EXPECT_STREQ("(1, (2, (3, 4), (5, 6)), 7)", ss.str().c_str());
}
}
}
TEST(strides, sizes_ctor)
{
{
ngraph::tensor_size size{2, 3, 4};
stringstream ss;
{
ss << size;
EXPECT_STREQ("(2, 3, 4)", ss.str().c_str());
EXPECT_EQ(element_type_float, size.get_type());
}
}
{
ngraph::tensor_size size{{2, 3}, 4};
stringstream ss;
{
ss << size;
EXPECT_STREQ("((2, 3), 4)", ss.str().c_str());
EXPECT_EQ(element_type_float, size.get_type());
}
}
{
ngraph::tensor_size size{{1, 2}, {3, 4}, 5, {6, 7}};
stringstream ss;
{
ss << size;
EXPECT_STREQ("((1, 2), (3, 4), 5, (6, 7))", ss.str().c_str());
EXPECT_EQ(element_type_float, size.get_type());
}
}
{
ngraph::tensor_size size{1, {2, {3, 4}, {5, 6}}, 7};
stringstream ss;
{
ss << size;
EXPECT_STREQ("(1, (2, (3, 4), (5, 6)), 7)", ss.str().c_str());
EXPECT_EQ(element_type_float, size.get_type());
}
}
{
ngraph::tensor_size size{{2, 3, 4}, element_type_int32_t};
stringstream ss;
{
ss << size;
EXPECT_STREQ("(2, 3, 4)", ss.str().c_str());
EXPECT_EQ(element_type_int32_t, size.get_type());
}
}
}
TEST(strides, sizes_copy)
{
{
ngraph::tensor_size size{2, 3, 4};
auto copy = size;
stringstream ss;
ss << copy;
EXPECT_STREQ("(2, 3, 4)", ss.str().c_str());
EXPECT_EQ(element_type_float, copy.get_type());
}
{
ngraph::tensor_size size{{2, 3}, 4};
auto copy = size;
stringstream ss;
ss << copy;
EXPECT_STREQ("((2, 3), 4)", ss.str().c_str());
EXPECT_EQ(element_type_float, copy.get_type());
}
{
ngraph::tensor_size size{{1, 2}, {3, 4}, 5, {6, 7}};
auto copy = size;
stringstream ss;
ss << copy;
EXPECT_STREQ("((1, 2), (3, 4), 5, (6, 7))", ss.str().c_str());
EXPECT_EQ(element_type_float, copy.get_type());
}
{
ngraph::tensor_size size{1, {2, {3, 4}, {5, 6}}, 7};
auto copy = size;
stringstream ss;
ss << copy;
EXPECT_STREQ("(1, (2, (3, 4), (5, 6)), 7)", ss.str().c_str());
EXPECT_EQ(element_type_float, copy.get_type());
}
}
TEST(strides, strides)
{
{
ngraph::tensor_size size{2, 3, 4};
{
stringstream ss;
ss << size.strides();
EXPECT_STREQ("(48, 16, 4)", ss.str().c_str());
}
}
{
ngraph::tensor_size size{5, 7, 9, 11};
{
stringstream ss;
ss << size.strides();
EXPECT_STREQ("(2772, 396, 44, 4)", ss.str().c_str());
}
}
{
ngraph::tensor_size size{{5, 7, 9}, 11};
{
stringstream ss;
ss << size.strides();
EXPECT_STREQ("(44, 4)", ss.str().c_str());
}
}
{
ngraph::tensor_size size{{{5, 7}, 9}, 11};
{
stringstream ss;
ss << size.strides();
EXPECT_STREQ("(44, 4)", ss.str().c_str());
}
}
}
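// Editorial note, not from the original file: the expected strings in the
// strides tests follow from row-major byte strides with 4-byte float elements.
// For size {2, 3, 4} the innermost stride is sizeof(float) = 4, the next is
// 4 * 4 = 16, and the outermost is 3 * 16 = 48, i.e. "(48, 16, 4)". Grouped
// sizes such as {{5, 7, 9}, 11} report only the outer strides from strides(),
// while full_strides() (next test) preserves the nesting. A minimal
// stand-alone check of the arithmetic:
#include <array>
#include <cassert>
#include <cstddef>
inline void byte_stride_sketch()
{
    const std::array<size_t, 3> sizes{2, 3, 4};
    std::array<size_t, 3> strides{};
    size_t stride = sizeof(float); // innermost stride = element size in bytes
    for (size_t i = sizes.size(); i-- > 0;)
    {
        strides[i] = stride;
        stride *= sizes[i];
    }
    assert((strides == std::array<size_t, 3>{48, 16, 4}));
}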
TEST(strides, full_strides)
{
{
ngraph::tensor_size size{2, 3, 4};
{
stringstream ss;
ss << size.full_strides();
EXPECT_STREQ("(48, 16, 4)", ss.str().c_str());
}
}
{
ngraph::tensor_size size{5, 7, 9, 11};
{
stringstream ss;
ss << size.full_strides();
EXPECT_STREQ("(2772, 396, 44, 4)", ss.str().c_str());
}
}
{
ngraph::tensor_size size{{5, 7, 9}, 11};
{
stringstream ss;
ss << size.full_strides();
EXPECT_STREQ("((2772, 396, 44), 4)", ss.str().c_str());
}
}
{
ngraph::tensor_size size{{{5, 7}, 9}, 11};
{
stringstream ss;
ss << size.full_strides();
EXPECT_STREQ("(((2772, 396), 44), 4)", ss.str().c_str());
}
}
}
@@ -12,14 +12,26 @@
# limitations under the License.
set (SRC
log.cpp
descriptor/input.cpp
descriptor/output.cpp
descriptor/tensor_view.cpp
descriptor/tensor.cpp
function.cpp
log.cpp
node.cpp
shape.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/constant.cpp
ops/convert.cpp
ops/dot.cpp
ops/op.cpp
ops/parameter.cpp
ops/tuple.cpp
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
pass/assign_tensors.cpp
pass/call_pass.cpp
pass/dump_sorted.cpp
@@ -35,24 +47,11 @@ set (SRC
runtime/call_frame.cpp
runtime/external_function.cpp
shape.cpp
visualize.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/constant.cpp
ops/convert.cpp
ops/dot.cpp
ops/op.cpp
ops/parameter.cpp
ops/tuple.cpp
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
tree.cpp
shape.cpp
types/element_type.cpp
types/type.cpp
util.cpp
visualize.cpp
)
# find_program (GRAPHVIZ dot)
@@ -70,10 +69,10 @@ include_directories(
SYSTEM
"${EIGEN_INCLUDE_DIR}"
)
add_library(ngraph SHARED ${SRC})
target_include_directories(ngraph PUBLIC "${NGRAPH_INCLUDE_PATH}")
if (APPLE)
set_property(TARGET ngraph PROPERTY PREFIX "lib")
set_property(TARGET ngraph PROPERTY OUTPUT_NAME "ngraph.so")
@@ -81,7 +80,7 @@ if (APPLE)
else()
include_directories("${MKLDNN_INCLUDE_DIR}")
endif()
#-----------------------------------------------------------------------------------------------
# Installation logic...
#-----------------------------------------------------------------------------------------------
......
#include "ngraph/tree.hpp"
#include "ngraph/util.hpp"
//================================================================================================
//
//================================================================================================
#pragma once
#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <vector>
#include "ngraph/util.hpp"
namespace ngraph
{
template <typename T>
class tree;
using scalar_tree = ngraph::tree<size_t>;
}
//================================================================================================
//
//================================================================================================
template <typename T>
class ngraph::tree
{
public:
tree(T s)
: m_list{}
, m_value{s}
, m_is_list{false}
{
}
tree(const std::initializer_list<tree<T>>& list)
: m_list{}
, m_value{0}
, m_is_list{true}
{
m_list = list;
}
tree(const std::vector<T>& list)
: m_list{}
, m_value{0}
, m_is_list{true}
{
for (auto s : list)
{
m_list.push_back(tree(s));
}
}
bool is_list() const { return m_is_list; }
T get_value() const { return m_value; }
const std::vector<tree>& get_list() const { return m_list; }
static void traverse_tree(tree& s, std::function<void(T*)> func)
{
if (s.is_list())
{
for (tree& s1 : s.m_list)
{
traverse_tree(s1, func);
}
}
else
{
func(&(s.m_value));
}
}
friend std::ostream& operator<<(std::ostream& out, const tree& s)
{
if (s.is_list())
{
out << "(" << join(s.get_list(), ", ") << ")";
}
else
{
out << s.get_value();
}
return out;
}
// Combine all leaf values with a user-supplied binary function.
T reduce(const std::function<T(T, T)>& func) const
{
T rc;
if (is_list())
{
switch (m_list.size())
{
case 0: rc = 0; break;
case 1: rc = m_list[0].reduce(func); break;
default:
rc = m_list[0].reduce(func);
for (size_t i = 1; i < m_list.size(); i++)
{
rc = func(rc, m_list[i].reduce(func));
}
break;
}
}
else
{
rc = m_value;
}
return rc;
}
private:
std::vector<tree> m_list;
T m_value;
bool m_is_list;
};
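// Editorial usage sketch, not part of the original header: build a nested
// scalar_tree and combine its leaves with reduce(). The lambdas and the
// expected totals are illustrative only.
#include <cassert>
inline void tree_usage_sketch()
{
    ngraph::scalar_tree t{{2, 3}, 4};

    // Sum of every leaf value: 2 + 3 + 4 == 9.
    size_t sum = t.reduce([](size_t a, size_t b) { return a + b; });
    assert(sum == 9);

    // Product of every leaf value: 2 * 3 * 4 == 24.
    size_t product = t.reduce([](size_t a, size_t b) { return a * b; });
    assert(product == 24);
}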