Commit 064fb0fc authored by Scott Cyphers's avatar Scott Cyphers

formatting

parent fc0455ba
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <vector>
#include <set> #include <set>
#include <vector>
// Names for types that aren't worth giving their own classes // Names for types that aren't worth giving their own classes
namespace ngraph namespace ngraph
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
class Parameter; class Parameter;
class ValueType; class ValueType;
template<typename T, typename ...A> template <typename T, typename... A>
std::shared_ptr<T> node(A&&... args) std::shared_ptr<T> node(A&&... args)
{ {
return std::make_shared<T>(args...); return std::make_shared<T>(args...);
......
...@@ -43,6 +43,7 @@ namespace ngraph ...@@ -43,6 +43,7 @@ namespace ngraph
bool operator==(const Type& other) const; bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); } bool operator!=(const Type& other) const { return !(*this == other); }
private: private:
static std::map<std::string, Type> m_element_list; static std::map<std::string, Type> m_element_list;
size_t m_bitwidth; size_t m_bitwidth;
...@@ -54,15 +55,19 @@ namespace ngraph ...@@ -54,15 +55,19 @@ namespace ngraph
// Provides a compile-time name for a C++ type. // Provides a compile-time name for a C++ type.
// Used in TraitedType for the string that supplies the C++ type name during code generation, // Used in TraitedType for the string that supplies the C++ type name during code generation,
// so it needs to be a valid C++ name. // so it needs to be a valid C++ name.
template<typename T> template <typename T>
const char* traited_type_name() const char* traited_type_name()
{ {
throw ngraph_error("Unknown type"); throw ngraph_error("Unknown type");
} }
// Define a type string for a type T. Will make traited_type_name<T>() return "T" // Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TTN( T ) \ #define NGRAPH_DEFINE_TTN(T) \
template<> constexpr const char* traited_type_name < T > () { return #T; } template <> \
constexpr const char* traited_type_name<T>() \
{ \
return #T; \
}
// Literals (and probably other things we don't know about yet) need to have their C++ types // Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides // and element types coordinated. Every element type corresponds to a TraitedType which provides
...@@ -83,31 +88,32 @@ namespace ngraph ...@@ -83,31 +88,32 @@ namespace ngraph
// This is the C++ type used to hold a value of this element type during compilation // This is the C++ type used to hold a value of this element type during compilation
using type = T; using type = T;
// This returns a reference to an instance of this element type. // This returns a reference to an instance of this element type.
static const TraitedType<T>& element_type(){ static const TraitedType<T>& element_type()
{
static TraitedType<T> t; static TraitedType<T> t;
return t; return t;
} }
}; };
NGRAPH_DEFINE_TTN( float ) NGRAPH_DEFINE_TTN(float)
using Float = TraitedType<float>; using Float = TraitedType<float>;
NGRAPH_DEFINE_TTN( int8_t ) NGRAPH_DEFINE_TTN(int8_t)
using Int8 = TraitedType<int8_t>; using Int8 = TraitedType<int8_t>;
NGRAPH_DEFINE_TTN( int32_t ) NGRAPH_DEFINE_TTN(int32_t)
using Int32 = TraitedType<int32_t>; using Int32 = TraitedType<int32_t>;
NGRAPH_DEFINE_TTN( int64_t ) NGRAPH_DEFINE_TTN(int64_t)
using Int64 = TraitedType<int64_t>; using Int64 = TraitedType<int64_t>;
NGRAPH_DEFINE_TTN( uint8_t ) NGRAPH_DEFINE_TTN(uint8_t)
using UInt8 = TraitedType<uint8_t>; using UInt8 = TraitedType<uint8_t>;
NGRAPH_DEFINE_TTN( uint32_t ) NGRAPH_DEFINE_TTN(uint32_t)
using UInt32 = TraitedType<uint32_t>; using UInt32 = TraitedType<uint32_t>;
NGRAPH_DEFINE_TTN( uint64_t ) NGRAPH_DEFINE_TTN(uint64_t)
using UInt64 = TraitedType<uint64_t>; using UInt64 = TraitedType<uint64_t>;
} }
} }
...@@ -31,6 +31,7 @@ namespace ngraph ...@@ -31,6 +31,7 @@ namespace ngraph
std::shared_ptr<Node> result() { return m_result; } std::shared_ptr<Node> result() { return m_result; }
std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; } std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; }
std::string name() const { return m_name; } std::string name() const { return m_name; }
protected: protected:
std::shared_ptr<Node> m_result; std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
size_t ngraph::Node::m_next_instance_id = 0; size_t ngraph::Node::m_next_instance_id = 0;
ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, std::shared_ptr<ValueType> type) ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
std::shared_ptr<ValueType> type)
: TypedValueMixin(type) : TypedValueMixin(type)
, m_arguments(arguments) , m_arguments(arguments)
, m_instance_id(m_next_instance_id++) , m_instance_id(m_next_instance_id++)
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#include <iostream> #include <iostream>
#include "type.hpp"
#include "common.hpp" #include "common.hpp"
#include "type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -32,11 +32,11 @@ namespace ngraph ...@@ -32,11 +32,11 @@ namespace ngraph
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
class Node : public TypedValueMixin, public std::enable_shared_from_this<Node> class Node : public TypedValueMixin, public std::enable_shared_from_this<Node>
{ {
protected: protected:
Node(const Nodes& arguments, std::shared_ptr<ValueType> type = nullptr); Node(const Nodes& arguments, std::shared_ptr<ValueType> type = nullptr);
virtual ~Node() {} virtual ~Node() {}
public: public:
/// A "one-liner" describing this node. /// A "one-liner" describing this node.
virtual std::string description() const = 0; virtual std::string description() const = 0;
......
...@@ -24,39 +24,51 @@ namespace ngraph ...@@ -24,39 +24,51 @@ namespace ngraph
{ {
namespace op namespace op
{ {
std::shared_ptr<Node> abs(const std::shared_ptr<Node>& arg); std::shared_ptr<Node> abs(const std::shared_ptr<Node>& arg);
std::shared_ptr<Node> add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> add(const std::shared_ptr<Node>& arg0,
std::shared_ptr<Node> ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> convert(); //std::shared_ptr<Node> convert();
//std::shared_ptr<Node> convolution(); //std::shared_ptr<Node> convolution();
std::shared_ptr<Node> divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> divide(const std::shared_ptr<Node>& arg0,
std::shared_ptr<Node> equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> equal(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> exp(const std::shared_ptr<Node>& arg0); std::shared_ptr<Node> exp(const std::shared_ptr<Node>& arg0);
std::shared_ptr<Node> floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> get_tuple_element(); //std::shared_ptr<Node> get_tuple_element();
std::shared_ptr<Node> greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> greater(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); //std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> less(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); //std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> log(const std::shared_ptr<Node>& arg0); std::shared_ptr<Node> log(const std::shared_ptr<Node>& arg0);
//std::shared_ptr<Node> logical(); and, or, not //std::shared_ptr<Node> logical(); and, or, not
std::shared_ptr<Node> maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> maximum(const std::shared_ptr<Node>& arg0,
std::shared_ptr<Node> minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> negative(const std::shared_ptr<Node>& arg0); std::shared_ptr<Node> negative(const std::shared_ptr<Node>& arg0);
//std::shared_ptr<Node> pad(); //std::shared_ptr<Node> pad();
std::shared_ptr<Node> power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> reduce(); //std::shared_ptr<Node> reduce();
// std::shared_ptr<Node> reduce_window(); // std::shared_ptr<Node> reduce_window();
std::shared_ptr<Node> remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& arg0, const Shape& shape); std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& arg0, const Shape& shape);
//std::shared_ptr<Node> reverse(); //std::shared_ptr<Node> reverse();
//std::shared_ptr<Node> rng(); //std::shared_ptr<Node> rng();
//std::shared_ptr<Node> select(); //std::shared_ptr<Node> select();
//std::shared_ptr<Node> select_scatter(); //std::shared_ptr<Node> select_scatter();
//std::shared_ptr<Node> slice(); //std::shared_ptr<Node> slice();
std::shared_ptr<Node> subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> transpose(); //std::shared_ptr<Node> transpose();
//std::shared_ptr<Node> while(); //std::shared_ptr<Node> while();
} }
...@@ -81,6 +93,7 @@ namespace ngraph ...@@ -81,6 +93,7 @@ namespace ngraph
class FunctionOp : public Op class FunctionOp : public Op
{ {
virtual std::string description() const override { return "FunctionOp"; } virtual std::string description() const override { return "FunctionOp"; }
protected: protected:
std::shared_ptr<Node> m_function; std::shared_ptr<Node> m_function;
}; };
...@@ -95,6 +108,7 @@ namespace ngraph ...@@ -95,6 +108,7 @@ namespace ngraph
// TODO: Implement for each op. This enables graphs to be built for now. // TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {} virtual void propagate_types() override {}
protected: protected:
BuiltinOp(const std::vector<std::shared_ptr<Node>>& args) BuiltinOp(const std::vector<std::shared_ptr<Node>>& args)
: Op(args) : Op(args)
......
...@@ -25,7 +25,9 @@ namespace ngraph ...@@ -25,7 +25,9 @@ namespace ngraph
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast. /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg. /// the remaining axes in shape must be the same as the shape of arg.
/// ///
BroadcastOp(const std::shared_ptr<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes) BroadcastOp(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: BuiltinOp({arg}) : BuiltinOp({arg})
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
namespace ngraph namespace ngraph
{ {
class ConvertOp : public BuiltinOp class ConvertOp : public BuiltinOp
{ {
public: public:
...@@ -28,13 +27,14 @@ namespace ngraph ...@@ -28,13 +27,14 @@ namespace ngraph
virtual std::string get_op_class_name() const override { return "convert"; } virtual std::string get_op_class_name() const override { return "convert"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
}; };
namespace op namespace op
{ {
std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type); std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg,
const ngraph::element::Type& element_type);
} }
} }
...@@ -31,6 +31,7 @@ namespace ngraph ...@@ -31,6 +31,7 @@ namespace ngraph
namespace op namespace op
{ {
std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
} }
} }
...@@ -51,7 +51,8 @@ namespace ngraph ...@@ -51,7 +51,8 @@ namespace ngraph
namespace op namespace op
{ {
/// Factory for frameworks /// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const std::shared_ptr<ValueType>& value_type = nullptr); std::shared_ptr<ngraph::Parameter>
parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
/// Convenience factory for tests /// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type, std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape); const Shape& shape);
......
...@@ -39,6 +39,7 @@ namespace ngraph ...@@ -39,6 +39,7 @@ namespace ngraph
operator const std::vector<size_t>&() const { return m_sizes; } operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; } bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; } bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
protected: protected:
std::vector<size_t> m_sizes; std::vector<size_t> m_sizes;
}; };
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp" #include "topological_sort.hpp"
#include "node.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -21,7 +21,7 @@ using namespace std; ...@@ -21,7 +21,7 @@ using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n) void ngraph::TopologicalSort::promote_node(Node* n)
{ {
for (auto dn=m_dependent_nodes.begin(); dn!=m_dependent_nodes.end(); dn++) for (auto dn = m_dependent_nodes.begin(); dn != m_dependent_nodes.end(); dn++)
{ {
if (dn->first > 0) // Skip zero as they should never be promoted if (dn->first > 0) // Skip zero as they should never be promoted
{ {
...@@ -30,7 +30,7 @@ void ngraph::TopologicalSort::promote_node(Node* n) ...@@ -30,7 +30,7 @@ void ngraph::TopologicalSort::promote_node(Node* n)
{ {
// found the node // found the node
dn->second.erase(it); dn->second.erase(it);
m_dependent_nodes[dn->first-1].push_back(n); m_dependent_nodes[dn->first - 1].push_back(n);
} }
} }
} }
...@@ -38,8 +38,7 @@ void ngraph::TopologicalSort::promote_node(Node* n) ...@@ -38,8 +38,7 @@ void ngraph::TopologicalSort::promote_node(Node* n)
void ngraph::TopologicalSort::process(node_ptr p) void ngraph::TopologicalSort::process(node_ptr p)
{ {
traverse_nodes(p, [&](node_ptr node) traverse_nodes(p, [&](node_ptr node) {
{
list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()]; list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()];
node_list.push_back(node.get()); node_list.push_back(node.get());
}); });
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include <memory>
#include <map>
#include <list> #include <list>
#include <map>
#include <memory>
#include <vector>
namespace ngraph namespace ngraph
{ {
......
...@@ -71,7 +71,10 @@ namespace ngraph ...@@ -71,7 +71,10 @@ namespace ngraph
{ {
} }
const std::vector<std::shared_ptr<ValueType>> get_element_types() const { return m_element_types; } const std::vector<std::shared_ptr<ValueType>> get_element_types() const
{
return m_element_types;
}
std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; } std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; }
virtual bool operator==(const std::shared_ptr<ValueType>& that) const override; virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
...@@ -95,7 +98,10 @@ namespace ngraph ...@@ -95,7 +98,10 @@ namespace ngraph
** Set the type ** Set the type
** /param type The new type ** /param type The new type
**/ **/
void set_value_type(const std::shared_ptr<ValueType>& value_type) { m_value_type = value_type; } void set_value_type(const std::shared_ptr<ValueType>& value_type)
{
m_value_type = value_type;
}
/** /**
** Set the type to be a tensor view type ** Set the type to be a tensor view type
** /param element_type The type of the tensor elements ** /param element_type The type of the tensor elements
...@@ -114,6 +120,7 @@ namespace ngraph ...@@ -114,6 +120,7 @@ namespace ngraph
** The type associated with this value. ** The type associated with this value.
**/ **/
const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; } const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; }
protected: protected:
std::shared_ptr<ValueType> m_value_type; std::shared_ptr<ValueType> m_value_type;
}; };
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <list>
#include <fstream>
#include <cstdio> #include <cstdio>
#include <fstream>
#include <list>
#include "visualize.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "util.hpp" #include "util.hpp"
#include "visualize.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -31,8 +31,7 @@ Visualize::Visualize(const string& name) ...@@ -31,8 +31,7 @@ Visualize::Visualize(const string& name)
void Visualize::add(node_ptr p) void Visualize::add(node_ptr p)
{ {
// map<size_t, list<node_ptr>> dependent_nodes; // map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node) traverse_nodes(p, [&](node_ptr node) {
{
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n"; m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n";
...@@ -42,7 +41,7 @@ void Visualize::add(node_ptr p) ...@@ -42,7 +41,7 @@ void Visualize::add(node_ptr p)
void Visualize::save_dot(const string& path) const void Visualize::save_dot(const string& path) const
{ {
auto tmp_file = path+".tmp"; auto tmp_file = path + ".tmp";
ofstream out(tmp_file); ofstream out(tmp_file);
if (out) if (out)
{ {
......
...@@ -24,7 +24,8 @@ void ConvertOp::propagate_types() ...@@ -24,7 +24,8 @@ void ConvertOp::propagate_types()
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg, const element::Type& element_type) shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg,
const element::Type& element_type)
{ {
return make_shared<ConvertOp>(arg, element_type); return make_shared<ConvertOp>(arg, element_type);
} }
...@@ -20,15 +20,18 @@ using namespace std; ...@@ -20,15 +20,18 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<DotOp>(arg0, arg1); return make_shared<DotOp>(arg0, arg1);
} }
void DotOp::propagate_types() void DotOp::propagate_types()
{ {
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type()); auto arg0_tensor_type =
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type()); dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type) if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{ {
throw ngraph_error("Arguments to dot must be tensor views"); throw ngraph_error("Arguments to dot must be tensor views");
......
...@@ -32,12 +32,14 @@ std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg) ...@@ -32,12 +32,14 @@ std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
return make_shared<AbsOp>(arg); return make_shared<AbsOp>(arg);
} }
std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<AddOp>(arg0, arg1); return make_shared<AddOp>(arg0, arg1);
} }
std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<CeilingOp>(arg0, arg1); return make_shared<CeilingOp>(arg0, arg1);
} }
...@@ -45,7 +47,8 @@ std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, con ...@@ -45,7 +47,8 @@ std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, con
// 'convert', // 'convert',
// 'convolution', // 'convolution',
std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<DivideOp>(arg0, arg1); return make_shared<DivideOp>(arg0, arg1);
} }
...@@ -55,7 +58,8 @@ std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0) ...@@ -55,7 +58,8 @@ std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
return make_shared<ExpOp>(arg0); return make_shared<ExpOp>(arg0);
} }
std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<FloorOp>(arg0, arg1); return make_shared<FloorOp>(arg0, arg1);
} }
...@@ -65,17 +69,20 @@ std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0) ...@@ -65,17 +69,20 @@ std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
return make_shared<LogOp>(arg0); return make_shared<LogOp>(arg0);
} }
std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MaximumOp>(arg0, arg1); return make_shared<MaximumOp>(arg0, arg1);
} }
std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MinimumOp>(arg0, arg1); return make_shared<MinimumOp>(arg0, arg1);
} }
std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MultiplyOp>(arg0, arg1); return make_shared<MultiplyOp>(arg0, arg1);
} }
...@@ -87,14 +94,16 @@ std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0) ...@@ -87,14 +94,16 @@ std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
// 'pad', // 'pad',
std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<PowerOp>(arg0, arg1); return make_shared<PowerOp>(arg0, arg1);
} }
//'reduce', //'reduce',
std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<RemainderOp>(arg0, arg1); return make_shared<RemainderOp>(arg0, arg1);
} }
...@@ -109,7 +118,8 @@ std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, con ...@@ -109,7 +118,8 @@ std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, con
// 'select', // 'select',
//'slice', //'slice',
std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{ {
return make_shared<SubtractOp>(arg0, arg1); return make_shared<SubtractOp>(arg0, arg1);
} }
......
...@@ -41,9 +41,7 @@ void Parameter::assign_function(Function* function, size_t index) ...@@ -41,9 +41,7 @@ void Parameter::assign_function(Function* function, size_t index)
m_index = index; m_index = index;
} }
void Parameter::propagate_types() void Parameter::propagate_types() {}
{
}
shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type) shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type)
{ {
......
...@@ -99,6 +99,4 @@ TEST(build_graph, literal) ...@@ -99,6 +99,4 @@ TEST(build_graph, literal)
} }
// Check argument inverses // Check argument inverses
TEST(build_graph, arg_inverse) TEST(build_graph, arg_inverse) {}
{
}
...@@ -29,7 +29,7 @@ using namespace ngraph; ...@@ -29,7 +29,7 @@ using namespace ngraph;
static bool validate_list(const vector<Node*>& nodes) static bool validate_list(const vector<Node*>& nodes)
{ {
bool rc = true; bool rc = true;
for (auto it=nodes.rbegin(); it!=nodes.rend(); it++) for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{ {
auto node_tmp = *it; auto node_tmp = *it;
auto dependencies_tmp = node_tmp->get_arguments(); auto dependencies_tmp = node_tmp->get_arguments();
...@@ -38,8 +38,8 @@ static bool validate_list(const vector<Node*>& nodes) ...@@ -38,8 +38,8 @@ static bool validate_list(const vector<Node*>& nodes)
{ {
dependencies.push_back(n.get()); dependencies.push_back(n.get());
} }
auto tmp = it+1; auto tmp = it + 1;
for (; tmp!=nodes.rend(); tmp++) for (; tmp != nodes.rend(); tmp++)
{ {
auto dep_tmp = *tmp; auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp); auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
...@@ -59,7 +59,7 @@ static bool validate_list(const vector<Node*>& nodes) ...@@ -59,7 +59,7 @@ static bool validate_list(const vector<Node*>& nodes)
TEST(topological_sort, basic) TEST(topological_sort, basic)
{ {
vector<shared_ptr<Parameter>> args; vector<shared_ptr<Parameter>> args;
for (int i=0; i<10; i++) for (int i = 0; i < 10; i++)
{ {
auto arg = op::parameter(element::Float::element_type(), {1}); auto arg = op::parameter(element::Float::element_type(), {1});
ASSERT_NE(nullptr, arg); ASSERT_NE(nullptr, arg);
......
...@@ -134,9 +134,7 @@ TEST(util, contains) ...@@ -134,9 +134,7 @@ TEST(util, contains)
EXPECT_FALSE(contains(v1, 8)); EXPECT_FALSE(contains(v1, 8));
} }
TEST(util, remove_from) TEST(util, remove_from) {}
{
}
TEST(util, reduce) TEST(util, reduce)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment