Commit 92c4d314 authored by Scott Cyphers's avatar Scott Cyphers

Merge branch 'master' into cyphers/names

parents f106a582 6a4225c0
...@@ -29,6 +29,9 @@ set (SRC ...@@ -29,6 +29,9 @@ set (SRC
ops/tuple.cpp ops/tuple.cpp
types/element_type.cpp types/element_type.cpp
types/type.cpp types/type.cpp
ngraph/node.cpp
ngraph/topological_sort.cpp
ngraph/visualize.cpp
) )
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled # NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
......
...@@ -37,7 +37,6 @@ namespace nervana ...@@ -37,7 +37,6 @@ namespace nervana
} }
constexpr const char* get_ptr(size_t offset) const { return &_string[offset]; } constexpr const char* get_ptr(size_t offset) const { return &_string[offset]; }
constexpr size_t size() const { return _size; } constexpr size_t size() const { return _size; }
private: private:
const char* _string; const char* _string;
size_t _size; size_t _size;
...@@ -45,9 +44,8 @@ namespace nervana ...@@ -45,9 +44,8 @@ namespace nervana
constexpr const char* find_last(conststring s, size_t offset, char ch) constexpr const char* find_last(conststring s, size_t offset, char ch)
{ {
return offset == 0 return offset == 0 ? s.get_ptr(0) : (s[offset] == ch ? s.get_ptr(offset + 1)
? s.get_ptr(0) : find_last(s, offset - 1, ch));
: (s[offset] == ch ? s.get_ptr(offset + 1) : find_last(s, offset - 1, ch));
} }
constexpr const char* find_last(conststring s, char ch) constexpr const char* find_last(conststring s, char ch)
...@@ -69,7 +67,6 @@ namespace nervana ...@@ -69,7 +67,6 @@ namespace nervana
~log_helper(); ~log_helper();
std::ostream& stream() { return _stream; } std::ostream& stream() { return _stream; }
private: private:
std::stringstream _stream; std::stringstream _stream;
}; };
......
...@@ -39,9 +39,8 @@ namespace ngraph ...@@ -39,9 +39,8 @@ namespace ngraph
return h(m_cname); return h(m_cname);
} }
//bool operator==(const Type& other) const; bool operator==(const Type& other) const;
//bool operator!=(const Type& other) const { return !(*this == other); } bool operator!=(const Type& other) const { return !(*this == other); }
private: private:
static std::map<std::string, Type> m_element_list; static std::map<std::string, Type> m_element_list;
size_t m_bitwidth; size_t m_bitwidth;
......
...@@ -31,11 +31,8 @@ namespace ngraph ...@@ -31,11 +31,8 @@ namespace ngraph
const std::vector<std::shared_ptr<Parameter>>& parameters); const std::vector<std::shared_ptr<Parameter>>& parameters);
Node::ptr result() { return m_result; } Node::ptr result() { return m_result; }
Parameter::ptr parameter(size_t i) { return m_parameters[i]; } Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
std::string name() const { return m_name; } std::string name() const { return m_name; }
protected: protected:
Node::ptr m_result; Node::ptr m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "node.hpp"
#include "op.hpp"
size_t ngraph::Node::m_next_instance_id = 0;
ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type)
: TypedValueMixin(type)
, m_arguments(arguments)
, m_instance_id(m_next_instance_id++)
{
// Add this node as a user of each argument.
for (auto node : m_arguments)
{
node->m_users.insert(node.get());
}
}
bool ngraph::Node::is_op() const
{
return dynamic_cast<const ngraph::Op*>(this) != nullptr;
}
bool ngraph::Node::is_parameter() const
{
return dynamic_cast<const ngraph::Parameter*>(this) != nullptr;
}
std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Node& node)
{
auto op_tmp = dynamic_cast<const ngraph::Op*>(&node);
auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node);
if (op_tmp)
{
out << "Op(" << op_tmp->node_id() << ")";
}
else if (parameter_tmp)
{
out << "Parameter(" << parameter_tmp->node_id() << ")";
}
else
{
out << "Node(" << node.node_id() << ")";
}
return out;
}
...@@ -38,19 +38,9 @@ namespace ngraph ...@@ -38,19 +38,9 @@ namespace ngraph
using ptr = std::shared_ptr<Node>; using ptr = std::shared_ptr<Node>;
protected: protected:
Node(const Nodes& arguments, ValueType::ptr type = nullptr) Node(const Nodes& arguments, ValueType::ptr type = nullptr);
: TypedValueMixin(type)
, m_arguments(arguments)
{
// Add this node as a user of each argument.
for (auto node : m_arguments)
{
node->m_users.insert(node.get());
}
}
virtual ~Node() {} virtual ~Node() {}
public: public:
/// A "one-liner" describing this node. /// A "one-liner" describing this node.
virtual std::string description() const = 0; virtual std::string description() const = 0;
...@@ -65,6 +55,8 @@ namespace ngraph ...@@ -65,6 +55,8 @@ namespace ngraph
std::string name() const { return m_name; } std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; } void name(const std::string& name) { m_name = name; }
virtual std::string node_id() const = 0;
/** /**
** Return true if this has the same implementing class as node. This ** Return true if this has the same implementing class as node. This
** will be used by the pattern matcher when comparing a pattern ** will be used by the pattern matcher when comparing a pattern
...@@ -75,9 +67,19 @@ namespace ngraph ...@@ -75,9 +67,19 @@ namespace ngraph
return typeid(*this) == typeid(*node.get()); return typeid(*this) == typeid(*node.get());
} }
bool is_op() const;
bool is_parameter() const;
size_t instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
protected: protected:
Nodes m_arguments; Nodes m_arguments;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
std::string m_name; std::string m_name;
size_t m_instance_id;
static size_t m_next_instance_id;
}; };
using node_ptr = std::shared_ptr<Node>;
} }
...@@ -73,6 +73,9 @@ namespace ngraph ...@@ -73,6 +73,9 @@ namespace ngraph
: Node(arguments, nullptr) : Node(arguments, nullptr)
{ {
} }
virtual std::string op_class_name() const = 0;
virtual std::string node_id() const;
}; };
/** /**
...@@ -82,7 +85,6 @@ namespace ngraph ...@@ -82,7 +85,6 @@ namespace ngraph
class FunctionOp : public Op class FunctionOp : public Op
{ {
virtual std::string description() const override { return "FunctionOp"; } virtual std::string description() const override { return "FunctionOp"; }
protected: protected:
Node::ptr m_function; Node::ptr m_function;
}; };
...@@ -96,11 +98,9 @@ namespace ngraph ...@@ -96,11 +98,9 @@ namespace ngraph
public: public:
virtual std::string description() const override { return "BuiltinOp"; } virtual std::string description() const override { return "BuiltinOp"; }
/// Name of the builtin op, for debugging and logging. /// Name of the builtin op, for debugging and logging.
virtual std::string op_name() const = 0;
// TODO: Implement for each op. This enables graphs to be built for now. // TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {} virtual void propagate_types() override {}
protected: protected:
BuiltinOp(const std::vector<Node::ptr>& args) BuiltinOp(const std::vector<Node::ptr>& args)
: Op(args) : Op(args)
...@@ -116,7 +116,7 @@ namespace ngraph ...@@ -116,7 +116,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "abs"; } virtual std::string op_class_name() const override { return "abs"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -127,7 +127,7 @@ namespace ngraph ...@@ -127,7 +127,7 @@ namespace ngraph
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "add"; } virtual std::string op_class_name() const override { return "add"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -139,7 +139,7 @@ namespace ngraph ...@@ -139,7 +139,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "ceiling"; } virtual std::string op_class_name() const override { return "ceiling"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -151,7 +151,7 @@ namespace ngraph ...@@ -151,7 +151,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "divide"; } virtual std::string op_class_name() const override { return "divide"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -163,7 +163,7 @@ namespace ngraph ...@@ -163,7 +163,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "equal"; } virtual std::string op_class_name() const override { return "equal"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -175,7 +175,7 @@ namespace ngraph ...@@ -175,7 +175,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "exp"; } virtual std::string op_class_name() const override { return "exp"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -187,7 +187,7 @@ namespace ngraph ...@@ -187,7 +187,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "floor"; } virtual std::string op_class_name() const override { return "floor"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -199,7 +199,7 @@ namespace ngraph ...@@ -199,7 +199,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "greater"; } virtual std::string op_class_name() const override { return "greater"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -211,7 +211,7 @@ namespace ngraph ...@@ -211,7 +211,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "less"; } virtual std::string op_class_name() const override { return "less"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -223,7 +223,7 @@ namespace ngraph ...@@ -223,7 +223,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "log"; } virtual std::string op_class_name() const override { return "log"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -235,7 +235,7 @@ namespace ngraph ...@@ -235,7 +235,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "max"; } virtual std::string op_class_name() const override { return "max"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -247,7 +247,7 @@ namespace ngraph ...@@ -247,7 +247,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "min"; } virtual std::string op_class_name() const override { return "min"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -259,7 +259,7 @@ namespace ngraph ...@@ -259,7 +259,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "multiply"; } virtual std::string op_class_name() const override { return "multiply"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -271,7 +271,7 @@ namespace ngraph ...@@ -271,7 +271,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "negative"; } virtual std::string op_class_name() const override { return "negative"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -283,7 +283,7 @@ namespace ngraph ...@@ -283,7 +283,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "power"; } virtual std::string op_class_name() const override { return "power"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -295,7 +295,7 @@ namespace ngraph ...@@ -295,7 +295,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "remainder"; } virtual std::string op_class_name() const override { return "remainder"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -308,7 +308,7 @@ namespace ngraph ...@@ -308,7 +308,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "reshape"; } virtual std::string op_class_name() const override { return "reshape"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
...@@ -322,7 +322,7 @@ namespace ngraph ...@@ -322,7 +322,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "subtract"; } virtual std::string op_class_name() const override { return "subtract"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
} }
...@@ -34,7 +34,7 @@ namespace ngraph ...@@ -34,7 +34,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "broadcast"; } virtual std::string op_class_name() const override { return "broadcast"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
......
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "concatenate"; } virtual std::string op_class_name() const override { return "concatenate"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <sstream>
#include "../element_type.hpp" #include "../element_type.hpp"
namespace ngraph namespace ngraph
...@@ -48,6 +50,12 @@ namespace ngraph ...@@ -48,6 +50,12 @@ namespace ngraph
} }
virtual std::string description() const override { return "ConstantScalar"; } virtual std::string description() const override { return "ConstantScalar"; }
virtual std::string node_id() const override
{
std::stringstream ss;
ss << description() << "_" << node_id();
return ss.str();
}
typename T::ctype value() const { return m_value; } typename T::ctype value() const { return m_value; }
......
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "convert"; } virtual std::string op_class_name() const override { return "convert"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
......
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "dot"; } virtual std::string op_class_name() const override { return "dot"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
......
...@@ -39,8 +39,8 @@ namespace ngraph ...@@ -39,8 +39,8 @@ namespace ngraph
Parameter(const ValueType::ptr& value_type); Parameter(const ValueType::ptr& value_type);
std::string description() const override { return "Parameter"; } std::string description() const override { return "Parameter"; }
virtual void propagate_types() override; virtual void propagate_types() override;
virtual std::string node_id() const override;
protected: protected:
Function* m_function; Function* m_function;
......
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "tuple"; } virtual std::string op_class_name() const override { return "tuple"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
...@@ -41,10 +41,8 @@ namespace ngraph ...@@ -41,10 +41,8 @@ namespace ngraph
** Conversion to a vector of sizes. ** Conversion to a vector of sizes.
**/ **/
operator const std::vector<size_t>&() const { return m_sizes; } operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; } bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; } bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
protected: protected:
std::vector<size_t> m_sizes; std::vector<size_t> m_sizes;
}; };
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "topological_sort.hpp"
void ngraph::TopologicalSort::process(node_ptr node)
{
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
namespace ngraph
{
class TopologicalSort;
class Node;
using node_ptr = std::shared_ptr<Node>;
}
class ngraph::TopologicalSort
{
public:
TopologicalSort();
static void process(node_ptr);
private:
};
...@@ -89,7 +89,6 @@ namespace ngraph ...@@ -89,7 +89,6 @@ namespace ngraph
** Construct empty tuple and add value types later. ** Construct empty tuple and add value types later.
**/ **/
TupleType() {} TupleType() {}
/** /**
** /param element_types A vector of types for the tuple elements ** /param element_types A vector of types for the tuple elements
**/ **/
...@@ -123,7 +122,6 @@ namespace ngraph ...@@ -123,7 +122,6 @@ namespace ngraph
** /param type The new type ** /param type The new type
**/ **/
void type(const ValueType::ptr& type) { m_type = type; } void type(const ValueType::ptr& type) { m_type = type; }
/** /**
** Set the type to be a tensor view type ** Set the type to be a tensor view type
** /param element_type The type of the tensor elements ** /param element_type The type of the tensor elements
...@@ -138,12 +136,10 @@ namespace ngraph ...@@ -138,12 +136,10 @@ namespace ngraph
** The type associated with this value. ** The type associated with this value.
**/ **/
ValueType::ptr type() { return m_type; } ValueType::ptr type() { return m_type; }
/** /**
** The type associated with this value. ** The type associated with this value.
**/ **/
const ValueType::ptr type() const { return m_type; } const ValueType::ptr type() const { return m_type; }
protected: protected:
ValueType::ptr m_type; ValueType::ptr m_type;
}; };
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <list>
#include <fstream>
#include <cstdio>
#include "visualize.hpp"
#include "ngraph/node.hpp"
#include "util.hpp"
using namespace ngraph;
using namespace std;
Visualize::Visualize(const string& name)
: m_name{name}
{
}
void Visualize::add(node_ptr p)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node)
{
for (auto arg : node->arguments())
{
m_ss << " " << arg->node_id() << " -> " << node->node_id() << ";\n";
}
});
}
void Visualize::save_dot(const string& path) const
{
auto tmp_file = path+".tmp";
ofstream out(tmp_file);
if (out)
{
out << "digraph " << m_name << "\n{\n";
out << m_ss.str();
out << "}\n";
out.close();
stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << path;
auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r");
pclose(stream);
// remove(tmp_file.c_str());
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <memory>
#include <set>
#include <sstream>
namespace ngraph
{
class Visualize;
class Node;
using node_ptr = std::shared_ptr<Node>;
}
class ngraph::Visualize
{
public:
Visualize(const std::string& name = "ngraph");
void add(node_ptr);
void save_dot(const std::string& path) const;
private:
std::stringstream m_ss;
std::string m_name;
};
...@@ -13,12 +13,20 @@ ...@@ -13,12 +13,20 @@
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <algorithm> #include <algorithm>
#include <sstream>
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
std::string ngraph::Op::node_id() const
{
stringstream ss;
ss << op_class_name() << "_" << m_instance_id;
return ss.str();
}
Node::ptr ngraph::op::abs(const Node::ptr& arg) Node::ptr ngraph::op::abs(const Node::ptr& arg)
{ {
return make_shared<AbsOp>(arg); return make_shared<AbsOp>(arg);
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <sstream>
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
...@@ -34,7 +36,9 @@ void Parameter::assign_function(Function* function, size_t index) ...@@ -34,7 +36,9 @@ void Parameter::assign_function(Function* function, size_t index)
m_index = index; m_index = index;
} }
void Parameter::propagate_types() {} void Parameter::propagate_types()
{
}
shared_ptr<Parameter> ngraph::op::parameter(const ValueType::ptr& value_type) shared_ptr<Parameter> ngraph::op::parameter(const ValueType::ptr& value_type)
{ {
...@@ -46,3 +50,10 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_ ...@@ -46,3 +50,10 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_
{ {
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape)); return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
} }
std::string ngraph::Parameter::node_id() const
{
stringstream ss;
ss << "parameter_" << m_instance_id;
return ss.str();
}
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <map> #include <map>
#include "util.hpp" #include "util.hpp"
#include "ngraph/node.hpp"
using namespace std; using namespace std;
...@@ -129,3 +130,25 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list) ...@@ -129,3 +130,25 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
} }
return seed; return seed;
} }
static void traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::function<void(std::shared_ptr<ngraph::Node>)> f,
std::set<size_t>& instances_seen)
{
f(p);
for (auto arg : p->arguments())
{
if (instances_seen.find(arg->instance_id()) == instances_seen.end())
{
instances_seen.insert(arg->instance_id());
traverse_nodes(arg, f, instances_seen);
}
}
}
void ngraph::traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::function<void(std::shared_ptr<ngraph::Node>)> f)
{
std::set<size_t> instances_seen;
::traverse_nodes(p, f, instances_seen);
}
...@@ -24,11 +24,12 @@ ...@@ -24,11 +24,12 @@
namespace ngraph namespace ngraph
{ {
class Node;
class stopwatch; class stopwatch;
extern std::map<std::string, stopwatch*> stopwatch_statistics; extern std::map<std::string, stopwatch*> stopwatch_statistics;
template <typename T> template <typename T>
std::string join(const T& v, const std::string& sep) std::string join(const T& v, const std::string& sep = ", ")
{ {
std::ostringstream ss; std::ostringstream ss;
for (const auto& x : v) for (const auto& x : v)
...@@ -148,7 +149,6 @@ namespace ngraph ...@@ -148,7 +149,6 @@ namespace ngraph
size_t get_total_milliseconds() const { return get_total_nanoseconds() / 1e6; } size_t get_total_milliseconds() const { return get_total_nanoseconds() / 1e6; }
size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; } size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; }
size_t get_total_nanoseconds() const { return m_total_time.count(); } size_t get_total_nanoseconds() const { return m_total_time.count(); }
private: private:
std::chrono::high_resolution_clock m_clock; std::chrono::high_resolution_clock m_clock;
std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time; std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time;
...@@ -194,4 +194,5 @@ namespace ngraph ...@@ -194,4 +194,5 @@ namespace ngraph
return a * b; return a * b;
} }
void traverse_nodes(std::shared_ptr<Node> p, std::function<void(std::shared_ptr<Node>)> f);
} // end namespace ngraph } // end namespace ngraph
...@@ -27,6 +27,8 @@ set (SRC ...@@ -27,6 +27,8 @@ set (SRC
tensor.cpp tensor.cpp
element_type.cpp element_type.cpp
uuid.cpp uuid.cpp
topological_sort.cpp
op.cpp
) )
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
......
...@@ -43,7 +43,6 @@ TEST(build_graph, build_simple) ...@@ -43,7 +43,6 @@ TEST(build_graph, build_simple)
auto broadcast_1 = op::broadcast(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0}); auto broadcast_1 = op::broadcast(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto b1 = myfun<BroadcastOp>(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0}); auto b1 = myfun<BroadcastOp>(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto dot = op::dot(arg2, arg0); auto dot = op::dot(arg2, arg0);
auto d1 = myfun<DotOp>(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2); ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0); ASSERT_EQ(dot->arguments()[1], arg0);
...@@ -115,4 +114,3 @@ TEST(build_graph, literal) ...@@ -115,4 +114,3 @@ TEST(build_graph, literal)
TEST(build_graph, arg_inverse) TEST(build_graph, arg_inverse)
{ {
} }
...@@ -12,15 +12,15 @@ ...@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <memory>
#include <dlfcn.h> #include <dlfcn.h>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph.hpp"
#include "log.hpp" #include "log.hpp"
#include "ngraph.hpp"
using namespace std; using namespace std;
......
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
TEST(op, is_op)
{
auto arg0 = op::parameter(element::Float::element_type(), {1});
ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op());
}
TEST(op, is_parameter)
{
auto arg0 = op::parameter(element::Float::element_type(), {1});
ASSERT_NE(nullptr, arg0);
auto t0 = op::add(arg0, arg0);
ASSERT_NE(nullptr, t0);
EXPECT_FALSE(t0->is_parameter());
EXPECT_TRUE(t0->is_op());
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/topological_sort.hpp"
#include "ngraph/visualize.hpp"
using namespace std;
using namespace ngraph;
TEST(top_sort, basic)
{
auto arg0 = op::parameter(element::Float::element_type(), {1});
ASSERT_NE(nullptr, arg0);
auto arg1 = op::parameter(element::Float::element_type(), {1});
ASSERT_NE(nullptr, arg1);
auto t0 = op::add(arg0, arg1);
ASSERT_NE(nullptr, t0);
auto t1 = op::add(arg0, arg1);
ASSERT_NE(nullptr, t1);
Node::ptr r0 = op::add(t0, t1);
ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, {arg0, arg1});
ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0);
cout << "op_r0 name " << *r0 << endl;
Visualize vz;
vz.add(r0);
vz.save_dot("test.png");
TopologicalSort::process(r0);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment