Commit 064fb0fc authored by Scott Cyphers's avatar Scott Cyphers

formatting

parent fc0455ba
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <vector>
#include <set> #include <set>
#include <vector>
// Names for types that aren't worth giving their own classes // Names for types that aren't worth giving their own classes
namespace ngraph namespace ngraph
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
class Parameter; class Parameter;
class ValueType; class ValueType;
template<typename T, typename ...A> template <typename T, typename... A>
std::shared_ptr<T> node(A&&... args) std::shared_ptr<T> node(A&&... args)
{ {
return std::make_shared<T>(args...); return std::make_shared<T>(args...);
...@@ -33,13 +33,13 @@ namespace ngraph ...@@ -33,13 +33,13 @@ namespace ngraph
/// Zero or more value types /// Zero or more value types
using ValueTypes = std::vector<std::shared_ptr<ValueType>>; using ValueTypes = std::vector<std::shared_ptr<ValueType>>;
/// Zero or more nodes /// Zero or more nodes
using Nodes = std::vector<std::shared_ptr<Node>>; using Nodes = std::vector<std::shared_ptr<Node>>;
/// A sequence of axes /// A sequence of axes
using AxisVector = std::vector<size_t>; using AxisVector = std::vector<size_t>;
/// A set of axes, for example, reduction axes /// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>; using AxisSet = std::set<size_t>;
......
...@@ -43,27 +43,32 @@ namespace ngraph ...@@ -43,27 +43,32 @@ namespace ngraph
bool operator==(const Type& other) const; bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); } bool operator!=(const Type& other) const { return !(*this == other); }
private: private:
static std::map<std::string, Type> m_element_list; static std::map<std::string, Type> m_element_list;
size_t m_bitwidth; size_t m_bitwidth;
bool m_is_float; bool m_is_float;
bool m_is_signed; bool m_is_signed;
const std::string m_cname; const std::string m_cname;
}; };
// Provides a compile-time name for a C++ type. // Provides a compile-time name for a C++ type.
// Used in TraitedType for the string that supplies the C++ type name during code generation, // Used in TraitedType for the string that supplies the C++ type name during code generation,
// so it needs to be a valid C++ name. // so it needs to be a valid C++ name.
template<typename T> template <typename T>
const char* traited_type_name() const char* traited_type_name()
{ {
throw ngraph_error("Unknown type"); throw ngraph_error("Unknown type");
} }
// Define a type string for a type T. Will make traited_type_name<T>() return "T" // Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TTN( T ) \ #define NGRAPH_DEFINE_TTN(T) \
template<> constexpr const char* traited_type_name < T > () { return #T; } template <> \
constexpr const char* traited_type_name<T>() \
{ \
return #T; \
}
// Literals (and probably other things we don't know about yet) need to have their C++ types // Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides // and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation. // access to both the instance and the C++ type used to hold the value during compilation.
...@@ -72,10 +77,10 @@ namespace ngraph ...@@ -72,10 +77,10 @@ namespace ngraph
{ {
protected: protected:
TraitedType() TraitedType()
: Type(sizeof(T) * 8, : Type(sizeof(T) * 8,
std::is_floating_point<T>::value, std::is_floating_point<T>::value,
std::is_signed<T>::value, std::is_signed<T>::value,
traited_type_name<T>()) traited_type_name<T>())
{ {
} }
...@@ -83,31 +88,32 @@ namespace ngraph ...@@ -83,31 +88,32 @@ namespace ngraph
// This is the C++ type used to hold a value of this element type during compilation // This is the C++ type used to hold a value of this element type during compilation
using type = T; using type = T;
// This returns a reference to an instance of this element type. // This returns a reference to an instance of this element type.
static const TraitedType<T>& element_type(){ static const TraitedType<T>& element_type()
{
static TraitedType<T> t; static TraitedType<T> t;
return t; return t;
} }
}; };
NGRAPH_DEFINE_TTN( float ) NGRAPH_DEFINE_TTN(float)
using Float = TraitedType<float>; using Float = TraitedType<float>;
NGRAPH_DEFINE_TTN( int8_t ) NGRAPH_DEFINE_TTN(int8_t)
using Int8 = TraitedType<int8_t>; using Int8 = TraitedType<int8_t>;
NGRAPH_DEFINE_TTN( int32_t ) NGRAPH_DEFINE_TTN(int32_t)
using Int32 = TraitedType<int32_t>; using Int32 = TraitedType<int32_t>;
NGRAPH_DEFINE_TTN( int64_t ) NGRAPH_DEFINE_TTN(int64_t)
using Int64 = TraitedType<int64_t>; using Int64 = TraitedType<int64_t>;
NGRAPH_DEFINE_TTN( uint8_t ) NGRAPH_DEFINE_TTN(uint8_t)
using UInt8 = TraitedType<uint8_t>; using UInt8 = TraitedType<uint8_t>;
NGRAPH_DEFINE_TTN( uint32_t ) NGRAPH_DEFINE_TTN(uint32_t)
using UInt32 = TraitedType<uint32_t>; using UInt32 = TraitedType<uint32_t>;
NGRAPH_DEFINE_TTN( uint64_t ) NGRAPH_DEFINE_TTN(uint64_t)
using UInt64 = TraitedType<uint64_t>; using UInt64 = TraitedType<uint64_t>;
} }
} }
...@@ -25,14 +25,15 @@ namespace ngraph ...@@ -25,14 +25,15 @@ namespace ngraph
class Function class Function
{ {
public: public:
Function(const std::shared_ptr<Node>& result, Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters); const std::vector<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Node> result() { return m_result; } std::shared_ptr<Node> result() { return m_result; }
std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; } std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; }
std::string name() const { return m_name; } std::string name() const { return m_name; }
protected: protected:
std::shared_ptr<Node> m_result; std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
std::string m_name; std::string m_name;
}; };
...@@ -40,10 +41,10 @@ namespace ngraph ...@@ -40,10 +41,10 @@ namespace ngraph
namespace op namespace op
{ {
std::shared_ptr<Function> std::shared_ptr<Function>
function(const std::shared_ptr<Node>& result, function(const std::shared_ptr<Node>& result,
const std::initializer_list<std::shared_ptr<Parameter>>& parameters); const std::initializer_list<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Function> std::shared_ptr<Function>
function(const std::shared_ptr<Node>& result, function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters); const std::vector<std::shared_ptr<Parameter>>& parameters);
} }
} }
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
size_t ngraph::Node::m_next_instance_id = 0; size_t ngraph::Node::m_next_instance_id = 0;
ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, std::shared_ptr<ValueType> type) ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
std::shared_ptr<ValueType> type)
: TypedValueMixin(type) : TypedValueMixin(type)
, m_arguments(arguments) , m_arguments(arguments)
, m_instance_id(m_next_instance_id++) , m_instance_id(m_next_instance_id++)
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#include <iostream> #include <iostream>
#include "type.hpp"
#include "common.hpp" #include "common.hpp"
#include "type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -32,11 +32,11 @@ namespace ngraph ...@@ -32,11 +32,11 @@ namespace ngraph
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
class Node : public TypedValueMixin, public std::enable_shared_from_this<Node> class Node : public TypedValueMixin, public std::enable_shared_from_this<Node>
{ {
protected: protected:
Node(const Nodes& arguments, std::shared_ptr<ValueType> type = nullptr); Node(const Nodes& arguments, std::shared_ptr<ValueType> type = nullptr);
virtual ~Node() {} virtual ~Node() {}
public: public:
/// A "one-liner" describing this node. /// A "one-liner" describing this node.
virtual std::string description() const = 0; virtual std::string description() const = 0;
...@@ -68,10 +68,10 @@ namespace ngraph ...@@ -68,10 +68,10 @@ namespace ngraph
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
protected: protected:
Nodes m_arguments; Nodes m_arguments;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
std::string m_name; std::string m_name;
size_t m_instance_id; size_t m_instance_id;
static size_t m_next_instance_id; static size_t m_next_instance_id;
}; };
} }
...@@ -24,39 +24,51 @@ namespace ngraph ...@@ -24,39 +24,51 @@ namespace ngraph
{ {
namespace op namespace op
{ {
std::shared_ptr<Node> abs(const std::shared_ptr<Node>& arg); std::shared_ptr<Node> abs(const std::shared_ptr<Node>& arg);
std::shared_ptr<Node> add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> add(const std::shared_ptr<Node>& arg0,
std::shared_ptr<Node> ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> convert(); //std::shared_ptr<Node> convert();
//std::shared_ptr<Node> convolution(); //std::shared_ptr<Node> convolution();
std::shared_ptr<Node> divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> divide(const std::shared_ptr<Node>& arg0,
std::shared_ptr<Node> equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> equal(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> exp(const std::shared_ptr<Node>& arg0); std::shared_ptr<Node> exp(const std::shared_ptr<Node>& arg0);
std::shared_ptr<Node> floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> get_tuple_element(); //std::shared_ptr<Node> get_tuple_element();
std::shared_ptr<Node> greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> greater(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); //std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> less(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); //std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> log(const std::shared_ptr<Node>& arg0); std::shared_ptr<Node> log(const std::shared_ptr<Node>& arg0);
//std::shared_ptr<Node> logical(); and, or, not //std::shared_ptr<Node> logical(); and, or, not
std::shared_ptr<Node> maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> maximum(const std::shared_ptr<Node>& arg0,
std::shared_ptr<Node> minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> negative(const std::shared_ptr<Node>& arg0); std::shared_ptr<Node> negative(const std::shared_ptr<Node>& arg0);
//std::shared_ptr<Node> pad(); //std::shared_ptr<Node> pad();
std::shared_ptr<Node> power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> reduce(); //std::shared_ptr<Node> reduce();
// std::shared_ptr<Node> reduce_window(); // std::shared_ptr<Node> reduce_window();
std::shared_ptr<Node> remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& arg0, const Shape& shape); std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& arg0, const Shape& shape);
//std::shared_ptr<Node> reverse(); //std::shared_ptr<Node> reverse();
//std::shared_ptr<Node> rng(); //std::shared_ptr<Node> rng();
//std::shared_ptr<Node> select(); //std::shared_ptr<Node> select();
//std::shared_ptr<Node> select_scatter(); //std::shared_ptr<Node> select_scatter();
//std::shared_ptr<Node> slice(); //std::shared_ptr<Node> slice();
std::shared_ptr<Node> subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> transpose(); //std::shared_ptr<Node> transpose();
//std::shared_ptr<Node> while(); //std::shared_ptr<Node> while();
} }
...@@ -81,6 +93,7 @@ namespace ngraph ...@@ -81,6 +93,7 @@ namespace ngraph
class FunctionOp : public Op class FunctionOp : public Op
{ {
virtual std::string description() const override { return "FunctionOp"; } virtual std::string description() const override { return "FunctionOp"; }
protected: protected:
std::shared_ptr<Node> m_function; std::shared_ptr<Node> m_function;
}; };
...@@ -95,6 +108,7 @@ namespace ngraph ...@@ -95,6 +108,7 @@ namespace ngraph
// TODO: Implement for each op. This enables graphs to be built for now. // TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {} virtual void propagate_types() override {}
protected: protected:
BuiltinOp(const std::vector<std::shared_ptr<Node>>& args) BuiltinOp(const std::vector<std::shared_ptr<Node>>& args)
: Op(args) : Op(args)
......
...@@ -25,7 +25,9 @@ namespace ngraph ...@@ -25,7 +25,9 @@ namespace ngraph
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast. /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg. /// the remaining axes in shape must be the same as the shape of arg.
/// ///
BroadcastOp(const std::shared_ptr<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes) BroadcastOp(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: BuiltinOp({arg}) : BuiltinOp({arg})
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
...@@ -36,14 +38,14 @@ namespace ngraph ...@@ -36,14 +38,14 @@ namespace ngraph
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
AxisSet m_broadcast_axes; AxisSet m_broadcast_axes;
}; };
namespace op namespace op
{ {
std::shared_ptr<Node> broadcast(const std::shared_ptr<Node>& tensor, std::shared_ptr<Node> broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape, const Shape& shape,
AxisSet&& broadcast_axes); AxisSet&& broadcast_axes);
} }
} }
...@@ -56,7 +56,7 @@ namespace ngraph ...@@ -56,7 +56,7 @@ namespace ngraph
ss << description() << "_" /* << node_id() */; ss << description() << "_" /* << node_id() */;
return ss.str(); return ss.str();
} }
typename T::type get_value() const { return m_value; } typename T::type get_value() const { return m_value; }
protected: protected:
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
namespace ngraph namespace ngraph
{ {
class ConvertOp : public BuiltinOp class ConvertOp : public BuiltinOp
{ {
public: public:
...@@ -28,13 +27,14 @@ namespace ngraph ...@@ -28,13 +27,14 @@ namespace ngraph
virtual std::string get_op_class_name() const override { return "convert"; } virtual std::string get_op_class_name() const override { return "convert"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
}; };
namespace op namespace op
{ {
std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type); std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg,
const ngraph::element::Type& element_type);
} }
} }
...@@ -31,6 +31,7 @@ namespace ngraph ...@@ -31,6 +31,7 @@ namespace ngraph
namespace op namespace op
{ {
std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
} }
} }
...@@ -51,9 +51,10 @@ namespace ngraph ...@@ -51,9 +51,10 @@ namespace ngraph
namespace op namespace op
{ {
/// Factory for frameworks /// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const std::shared_ptr<ValueType>& value_type = nullptr); std::shared_ptr<ngraph::Parameter>
parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
/// Convenience factory for tests /// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type, std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape); const Shape& shape);
} }
} }
...@@ -36,9 +36,10 @@ namespace ngraph ...@@ -36,9 +36,10 @@ namespace ngraph
} }
/// Conversion to a vector of sizes. /// Conversion to a vector of sizes.
operator const std::vector<size_t>&() const { return m_sizes; } operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; } bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; } bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
protected: protected:
std::vector<size_t> m_sizes; std::vector<size_t> m_sizes;
}; };
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp" #include "topological_sort.hpp"
#include "node.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -21,16 +21,16 @@ using namespace std; ...@@ -21,16 +21,16 @@ using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n) void ngraph::TopologicalSort::promote_node(Node* n)
{ {
for (auto dn=m_dependent_nodes.begin(); dn!=m_dependent_nodes.end(); dn++) for (auto dn = m_dependent_nodes.begin(); dn != m_dependent_nodes.end(); dn++)
{ {
if (dn->first > 0) // Skip zero as they should never be promoted if (dn->first > 0) // Skip zero as they should never be promoted
{ {
auto it = find(dn->second.begin(), dn->second.end(), n); auto it = find(dn->second.begin(), dn->second.end(), n);
if (it != dn->second.end()) if (it != dn->second.end())
{ {
// found the node // found the node
dn->second.erase(it); dn->second.erase(it);
m_dependent_nodes[dn->first-1].push_back(n); m_dependent_nodes[dn->first - 1].push_back(n);
} }
} }
} }
...@@ -38,8 +38,7 @@ void ngraph::TopologicalSort::promote_node(Node* n) ...@@ -38,8 +38,7 @@ void ngraph::TopologicalSort::promote_node(Node* n)
void ngraph::TopologicalSort::process(node_ptr p) void ngraph::TopologicalSort::process(node_ptr p)
{ {
traverse_nodes(p, [&](node_ptr node) traverse_nodes(p, [&](node_ptr node) {
{
list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()]; list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()];
node_list.push_back(node.get()); node_list.push_back(node.get());
}); });
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include <memory>
#include <map>
#include <list> #include <list>
#include <map>
#include <memory>
#include <vector>
namespace ngraph namespace ngraph
{ {
...@@ -30,7 +31,7 @@ class ngraph::TopologicalSort ...@@ -30,7 +31,7 @@ class ngraph::TopologicalSort
public: public:
TopologicalSort() {} TopologicalSort() {}
void process(node_ptr); void process(node_ptr);
const std::vector<Node*>& get_sorted_list() const; const std::vector<Node*>& get_sorted_list() const;
private: private:
......
...@@ -33,7 +33,7 @@ namespace ngraph ...@@ -33,7 +33,7 @@ namespace ngraph
public: public:
virtual ~ValueType() {} virtual ~ValueType() {}
virtual bool operator==(const std::shared_ptr<ValueType>& that) const = 0; virtual bool operator==(const std::shared_ptr<ValueType>& that) const = 0;
bool operator!=(const std::shared_ptr<ValueType>& that) const { return !(*this == that); } bool operator!=(const std::shared_ptr<ValueType>& that) const { return !(*this == that); }
}; };
/// Describes a tensor view; an element type and a shape. /// Describes a tensor view; an element type and a shape.
...@@ -71,8 +71,11 @@ namespace ngraph ...@@ -71,8 +71,11 @@ namespace ngraph
{ {
} }
const std::vector<std::shared_ptr<ValueType>> get_element_types() const { return m_element_types; } const std::vector<std::shared_ptr<ValueType>> get_element_types() const
std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; } {
return m_element_types;
}
std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; }
virtual bool operator==(const std::shared_ptr<ValueType>& that) const override; virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
...@@ -95,7 +98,10 @@ namespace ngraph ...@@ -95,7 +98,10 @@ namespace ngraph
** Set the type ** Set the type
** /param type The new type ** /param type The new type
**/ **/
void set_value_type(const std::shared_ptr<ValueType>& value_type) { m_value_type = value_type; } void set_value_type(const std::shared_ptr<ValueType>& value_type)
{
m_value_type = value_type;
}
/** /**
** Set the type to be a tensor view type ** Set the type to be a tensor view type
** /param element_type The type of the tensor elements ** /param element_type The type of the tensor elements
...@@ -114,6 +120,7 @@ namespace ngraph ...@@ -114,6 +120,7 @@ namespace ngraph
** The type associated with this value. ** The type associated with this value.
**/ **/
const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; } const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; }
protected: protected:
std::shared_ptr<ValueType> m_value_type; std::shared_ptr<ValueType> m_value_type;
}; };
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <list>
#include <fstream>
#include <cstdio> #include <cstdio>
#include <fstream>
#include <list>
#include "visualize.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "util.hpp" #include "util.hpp"
#include "visualize.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -31,8 +31,7 @@ Visualize::Visualize(const string& name) ...@@ -31,8 +31,7 @@ Visualize::Visualize(const string& name)
void Visualize::add(node_ptr p) void Visualize::add(node_ptr p)
{ {
// map<size_t, list<node_ptr>> dependent_nodes; // map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node) traverse_nodes(p, [&](node_ptr node) {
{
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n"; m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n";
...@@ -42,7 +41,7 @@ void Visualize::add(node_ptr p) ...@@ -42,7 +41,7 @@ void Visualize::add(node_ptr p)
void Visualize::save_dot(const string& path) const void Visualize::save_dot(const string& path) const
{ {
auto tmp_file = path+".tmp"; auto tmp_file = path + ".tmp";
ofstream out(tmp_file); ofstream out(tmp_file);
if (out) if (out)
{ {
...@@ -53,7 +52,7 @@ void Visualize::save_dot(const string& path) const ...@@ -53,7 +52,7 @@ void Visualize::save_dot(const string& path) const
stringstream ss; stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << path; ss << "dot -Tpng " << tmp_file << " -o " << path;
auto cmd = ss.str(); auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r"); auto stream = popen(cmd.c_str(), "r");
pclose(stream); pclose(stream);
......
...@@ -21,9 +21,9 @@ using namespace ngraph; ...@@ -21,9 +21,9 @@ using namespace ngraph;
/// @param shape The shape of the result /// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast. /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg. /// the remaining axes in shape must be the same as the shape of arg.
std::shared_ptr<Node> ngraph::op::broadcast(const std::shared_ptr<Node>& tensor, std::shared_ptr<Node> ngraph::op::broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape, const Shape& shape,
AxisSet&& broadcast_axes) AxisSet&& broadcast_axes)
{ {
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes); return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
} }
......
...@@ -24,7 +24,8 @@ void ConvertOp::propagate_types() ...@@ -24,7 +24,8 @@ void ConvertOp::propagate_types()
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg, const element::Type& element_type) shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg,
const element::Type& element_type)
{ {
return make_shared<ConvertOp>(arg, element_type); return make_shared<ConvertOp>(arg, element_type);
} }
...@@ -20,15 +20,18 @@ using namespace std; ...@@ -20,15 +20,18 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<DotOp>(arg0, arg1); return make_shared<DotOp>(arg0, arg1);
} }
void DotOp::propagate_types() void DotOp::propagate_types()
{ {
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type()); auto arg0_tensor_type =
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type()); dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type) if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{ {
throw ngraph_error("Arguments to dot must be tensor views"); throw ngraph_error("Arguments to dot must be tensor views");
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
Function::Function(const std::shared_ptr<Node>& result, Function::Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters) const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters)
: m_result(result) : m_result(result)
, m_parameters(parameters) , m_parameters(parameters)
...@@ -30,13 +30,13 @@ Function::Function(const std::shared_ptr<Node>& ...@@ -30,13 +30,13 @@ Function::Function(const std::shared_ptr<Node>&
} }
} }
shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result, shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const initializer_list<shared_ptr<Parameter>>& parameters) const initializer_list<shared_ptr<Parameter>>& parameters)
{ {
return make_shared<Function>(result, parameters); return make_shared<Function>(result, parameters);
} }
shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result, shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const vector<shared_ptr<Parameter>>& parameters) const vector<shared_ptr<Parameter>>& parameters)
{ {
return make_shared<Function>(result, parameters); return make_shared<Function>(result, parameters);
......
...@@ -32,12 +32,14 @@ std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg) ...@@ -32,12 +32,14 @@ std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
return make_shared<AbsOp>(arg); return make_shared<AbsOp>(arg);
} }
std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<AddOp>(arg0, arg1); return make_shared<AddOp>(arg0, arg1);
} }
std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<CeilingOp>(arg0, arg1); return make_shared<CeilingOp>(arg0, arg1);
} }
...@@ -45,7 +47,8 @@ std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, con ...@@ -45,7 +47,8 @@ std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, con
// 'convert', // 'convert',
// 'convolution', // 'convolution',
std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<DivideOp>(arg0, arg1); return make_shared<DivideOp>(arg0, arg1);
} }
...@@ -55,7 +58,8 @@ std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0) ...@@ -55,7 +58,8 @@ std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
return make_shared<ExpOp>(arg0); return make_shared<ExpOp>(arg0);
} }
std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<FloorOp>(arg0, arg1); return make_shared<FloorOp>(arg0, arg1);
} }
...@@ -65,17 +69,20 @@ std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0) ...@@ -65,17 +69,20 @@ std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
return make_shared<LogOp>(arg0); return make_shared<LogOp>(arg0);
} }
std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MaximumOp>(arg0, arg1); return make_shared<MaximumOp>(arg0, arg1);
} }
std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MinimumOp>(arg0, arg1); return make_shared<MinimumOp>(arg0, arg1);
} }
std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MultiplyOp>(arg0, arg1); return make_shared<MultiplyOp>(arg0, arg1);
} }
...@@ -87,14 +94,16 @@ std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0) ...@@ -87,14 +94,16 @@ std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
// 'pad', // 'pad',
std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<PowerOp>(arg0, arg1); return make_shared<PowerOp>(arg0, arg1);
} }
//'reduce', //'reduce',
std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<RemainderOp>(arg0, arg1); return make_shared<RemainderOp>(arg0, arg1);
} }
...@@ -109,7 +118,8 @@ std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, con ...@@ -109,7 +118,8 @@ std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, con
// 'select', // 'select',
//'slice', //'slice',
std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<SubtractOp>(arg0, arg1); return make_shared<SubtractOp>(arg0, arg1);
} }
......
...@@ -41,9 +41,7 @@ void Parameter::assign_function(Function* function, size_t index) ...@@ -41,9 +41,7 @@ void Parameter::assign_function(Function* function, size_t index)
m_index = index; m_index = index;
} }
void Parameter::propagate_types() void Parameter::propagate_types() {}
{
}
shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type) shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type)
{ {
......
...@@ -28,7 +28,7 @@ TEST(build_graph, build_simple) ...@@ -28,7 +28,7 @@ TEST(build_graph, build_simple)
auto arg2 = node<Parameter>(element::Float::element_type(), Shape{32, 7}); auto arg2 = node<Parameter>(element::Float::element_type(), Shape{32, 7});
auto arg3 = node<Parameter>(element::Float::element_type(), Shape{32, 7}); auto arg3 = node<Parameter>(element::Float::element_type(), Shape{32, 7});
auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = node<DotOp>(arg2, arg0); auto dot = node<DotOp>(arg2, arg0);
ASSERT_EQ(dot->get_arguments()[0], arg2); ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0); ASSERT_EQ(dot->get_arguments()[1], arg0);
...@@ -50,7 +50,7 @@ TEST(build_graph, as_type) ...@@ -50,7 +50,7 @@ TEST(build_graph, as_type)
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple. // Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
auto tp_vt = make_shared<TupleType>(ValueTypes{tv_vt, tv_vt}); auto tp_vt = make_shared<TupleType>(ValueTypes{tv_vt, tv_vt});
auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt); auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt);
ASSERT_EQ(nullptr, tp_tv); ASSERT_EQ(nullptr, tp_tv);
auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt); auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt);
ASSERT_EQ(tp_vt, tp_tp); ASSERT_EQ(tp_vt, tp_tp);
...@@ -78,8 +78,8 @@ TEST(build_graph, literal) ...@@ -78,8 +78,8 @@ TEST(build_graph, literal)
{ {
// float scalar from a float // float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0); //auto float0 = FloatScalarConstant::make(3.0);
auto float0 = node<FloatScalarConstant>(3.0); auto float0 = node<FloatScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{}); auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{});
ASSERT_EQ(float0->get_value(), 3.0); ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->get_value_type(), float_scalar_type); ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
auto d = node<DotOp>(float0, float0); auto d = node<DotOp>(float0, float0);
...@@ -90,15 +90,13 @@ TEST(build_graph, literal) ...@@ -90,15 +90,13 @@ TEST(build_graph, literal)
auto float1 = node<FloatScalarConstant>(3); auto float1 = node<FloatScalarConstant>(3);
ASSERT_EQ(float1->get_value(), 3); ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->get_value_type(), float_scalar_type); ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
auto int32_0 = node<Int32ScalarConstant>(3.0); auto int32_0 = node<Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{}); auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->get_value(), 3); ASSERT_EQ(int32_0->get_value(), 3);
ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type); ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type);
ASSERT_NE(*int32_0->get_value_type(), float_scalar_type); ASSERT_NE(*int32_0->get_value_type(), float_scalar_type);
} }
// Check argument inverses // Check argument inverses
TEST(build_graph, arg_inverse) TEST(build_graph, arg_inverse) {}
{
}
...@@ -29,20 +29,20 @@ using namespace ngraph; ...@@ -29,20 +29,20 @@ using namespace ngraph;
static bool validate_list(const vector<Node*>& nodes) static bool validate_list(const vector<Node*>& nodes)
{ {
bool rc = true; bool rc = true;
for (auto it=nodes.rbegin(); it!=nodes.rend(); it++) for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{ {
auto node_tmp = *it; auto node_tmp = *it;
auto dependencies_tmp = node_tmp->get_arguments(); auto dependencies_tmp = node_tmp->get_arguments();
vector<Node*> dependencies; vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp) for (shared_ptr<Node> n : dependencies_tmp)
{ {
dependencies.push_back(n.get()); dependencies.push_back(n.get());
} }
auto tmp = it+1; auto tmp = it + 1;
for (; tmp!=nodes.rend(); tmp++) for (; tmp != nodes.rend(); tmp++)
{ {
auto dep_tmp = *tmp; auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp); auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
if (found != dependencies.end()) if (found != dependencies.end())
{ {
dependencies.erase(found); dependencies.erase(found);
...@@ -59,7 +59,7 @@ static bool validate_list(const vector<Node*>& nodes) ...@@ -59,7 +59,7 @@ static bool validate_list(const vector<Node*>& nodes)
TEST(topological_sort, basic) TEST(topological_sort, basic)
{ {
vector<shared_ptr<Parameter>> args; vector<shared_ptr<Parameter>> args;
for (int i=0; i<10; i++) for (int i = 0; i < 10; i++)
{ {
auto arg = op::parameter(element::Float::element_type(), {1}); auto arg = op::parameter(element::Float::element_type(), {1});
ASSERT_NE(nullptr, arg); ASSERT_NE(nullptr, arg);
......
...@@ -134,9 +134,7 @@ TEST(util, contains) ...@@ -134,9 +134,7 @@ TEST(util, contains)
EXPECT_FALSE(contains(v1, 8)); EXPECT_FALSE(contains(v1, 8));
} }
TEST(util, remove_from) TEST(util, remove_from) {}
{
}
TEST(util, reduce) TEST(util, reduce)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment