Commit 6a4225c0 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #71 from NervanaSystems/bob/top4

Add topological sort and viualization classes
parents 5f724e48 c16cc639
......@@ -29,6 +29,9 @@ set (SRC
ops/tuple.cpp
types/element_type.cpp
types/type.cpp
ngraph/node.cpp
ngraph/topological_sort.cpp
ngraph/visualize.cpp
)
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
......
......@@ -36,8 +36,7 @@ namespace nervana
return i < _size ? _string[i] : throw std::out_of_range("");
}
constexpr const char* get_ptr(size_t offset) const { return &_string[offset]; }
constexpr size_t size() const { return _size; }
constexpr size_t size() const { return _size; }
private:
const char* _string;
size_t _size;
......@@ -45,9 +44,8 @@ namespace nervana
constexpr const char* find_last(conststring s, size_t offset, char ch)
{
return offset == 0
? s.get_ptr(0)
: (s[offset] == ch ? s.get_ptr(offset + 1) : find_last(s, offset - 1, ch));
return offset == 0 ? s.get_ptr(0) : (s[offset] == ch ? s.get_ptr(offset + 1)
: find_last(s, offset - 1, ch));
}
constexpr const char* find_last(conststring s, char ch)
......@@ -69,7 +67,6 @@ namespace nervana
~log_helper();
std::ostream& stream() { return _stream; }
private:
std::stringstream _stream;
};
......@@ -84,9 +81,9 @@ namespace nervana
static void stop();
private:
static void log_item(const std::string& s);
static void process_event(const std::string& s);
static void thread_entry(void* param);
static void log_item(const std::string& s);
static void process_event(const std::string& s);
static void thread_entry(void* param);
static std::string log_path;
static std::deque<std::string> queue;
};
......
......@@ -41,13 +41,12 @@ namespace ngraph
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
private:
static std::map<std::string, Type> m_element_list;
size_t m_bitwidth;
bool m_is_float;
bool m_is_signed;
const std::string m_cname;
size_t m_bitwidth;
bool m_is_float;
bool m_is_signed;
const std::string m_cname;
};
// Literals (and probably other things we don't know about yet) need to have their C++ types
......
......@@ -30,12 +30,9 @@ namespace ngraph
Function(const Node::ptr& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
Node::ptr result() { return m_result; }
Node::ptr result() { return m_result; }
Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
std::string name() const { return m_name; }
std::string name() const { return m_name; }
protected:
Node::ptr m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "node.hpp"
#include "op.hpp"
size_t ngraph::Node::m_next_instance_id = 0;
ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type)
: TypedValueMixin(type)
, m_arguments(arguments)
, m_instance_id(m_next_instance_id++)
{
// Add this node as a user of each argument.
for (auto node : m_arguments)
{
node->m_users.insert(node.get());
}
}
bool ngraph::Node::is_op() const
{
return dynamic_cast<const ngraph::Op*>(this) != nullptr;
}
bool ngraph::Node::is_parameter() const
{
return dynamic_cast<const ngraph::Parameter*>(this) != nullptr;
}
std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Node& node)
{
auto op_tmp = dynamic_cast<const ngraph::Op*>(&node);
auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node);
if (op_tmp)
{
out << "Op(" << op_tmp->node_id() << ")";
}
else if (parameter_tmp)
{
out << "Parameter(" << parameter_tmp->node_id() << ")";
}
else
{
out << "Node(" << node.node_id() << ")";
}
return out;
}
......@@ -38,19 +38,9 @@ namespace ngraph
using ptr = std::shared_ptr<Node>;
protected:
Node(const Nodes& arguments, ValueType::ptr type = nullptr)
: TypedValueMixin(type)
, m_arguments(arguments)
{
// Add this node as a user of each argument.
for (auto node : m_arguments)
{
node->m_users.insert(node.get());
}
}
Node(const Nodes& arguments, ValueType::ptr type = nullptr);
virtual ~Node() {}
public:
/// A "one-liner" describing this node.
virtual std::string description() const = 0;
......@@ -65,6 +55,8 @@ namespace ngraph
std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; }
virtual std::string node_id() const = 0;
/**
** Return true if this has the same implementing class as node. This
** will be used by the pattern matcher when comparing a pattern
......@@ -75,9 +67,19 @@ namespace ngraph
return typeid(*this) == typeid(*node.get());
}
bool is_op() const;
bool is_parameter() const;
size_t instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
protected:
Nodes m_arguments;
std::multiset<Node*> m_users;
std::string m_name;
size_t m_instance_id;
static size_t m_next_instance_id;
};
using node_ptr = std::shared_ptr<Node>;
}
......@@ -72,6 +72,9 @@ namespace ngraph
: Node(arguments, nullptr)
{
}
virtual std::string op_class_name() const = 0;
virtual std::string node_id() const;
};
/**
......@@ -81,7 +84,6 @@ namespace ngraph
class FunctionOp : public Op
{
virtual std::string description() const override { return "FunctionOp"; }
protected:
Node::ptr m_function;
};
......@@ -95,11 +97,9 @@ namespace ngraph
public:
virtual std::string description() const override { return "BuiltinOp"; }
/// Name of the builtin op, for debugging and logging.
virtual std::string op_name() const = 0;
// TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {}
protected:
BuiltinOp(const std::vector<Node::ptr>& args)
: Op(args)
......@@ -115,7 +115,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "abs"; }
virtual std::string op_class_name() const override { return "abs"; }
//virtual void propagate_types() override;
};
......@@ -126,7 +126,7 @@ namespace ngraph
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "add"; }
virtual std::string op_class_name() const override { return "add"; }
//virtual void propagate_types() override;
};
......@@ -138,7 +138,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "ceiling"; }
virtual std::string op_class_name() const override { return "ceiling"; }
//virtual void propagate_types() override;
};
......@@ -150,7 +150,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "divide"; }
virtual std::string op_class_name() const override { return "divide"; }
//virtual void propagate_types() override;
};
......@@ -162,7 +162,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "equal"; }
virtual std::string op_class_name() const override { return "equal"; }
//virtual void propagate_types() override;
};
......@@ -174,7 +174,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "exp"; }
virtual std::string op_class_name() const override { return "exp"; }
//virtual void propagate_types() override;
};
......@@ -186,7 +186,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "floor"; }
virtual std::string op_class_name() const override { return "floor"; }
//virtual void propagate_types() override;
};
......@@ -198,7 +198,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "greater"; }
virtual std::string op_class_name() const override { return "greater"; }
//virtual void propagate_types() override;
};
......@@ -210,7 +210,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "less"; }
virtual std::string op_class_name() const override { return "less"; }
//virtual void propagate_types() override;
};
......@@ -222,7 +222,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "log"; }
virtual std::string op_class_name() const override { return "log"; }
//virtual void propagate_types() override;
};
......@@ -234,7 +234,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "max"; }
virtual std::string op_class_name() const override { return "max"; }
//virtual void propagate_types() override;
};
......@@ -246,7 +246,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "min"; }
virtual std::string op_class_name() const override { return "min"; }
//virtual void propagate_types() override;
};
......@@ -258,7 +258,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "multiply"; }
virtual std::string op_class_name() const override { return "multiply"; }
//virtual void propagate_types() override;
};
......@@ -270,7 +270,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "negate"; }
virtual std::string op_class_name() const override { return "negate"; }
//virtual void propagate_types() override;
};
......@@ -282,7 +282,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "power"; }
virtual std::string op_class_name() const override { return "power"; }
//virtual void propagate_types() override;
};
......@@ -294,7 +294,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "remainder"; }
virtual std::string op_class_name() const override { return "remainder"; }
//virtual void propagate_types() override;
};
......@@ -307,7 +307,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "reshape"; }
virtual std::string op_class_name() const override { return "reshape"; }
//virtual void propagate_types() override;
protected:
Shape m_shape;
......@@ -321,7 +321,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "subtract"; }
virtual std::string op_class_name() const override { return "subtract"; }
//virtual void propagate_types() override;
};
}
......@@ -34,7 +34,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "broadcast"; }
virtual std::string op_class_name() const override { return "broadcast"; }
virtual void propagate_types() override;
protected:
......
......@@ -29,7 +29,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "concatenate"; }
virtual std::string op_class_name() const override { return "concatenate"; }
virtual void propagate_types() override;
};
}
......@@ -14,6 +14,8 @@
#pragma once
#include <sstream>
#include "../element_type.hpp"
namespace ngraph
......@@ -48,7 +50,13 @@ namespace ngraph
}
virtual std::string description() const override { return "ConstantScalar"; }
virtual std::string node_id() const override
{
std::stringstream ss;
ss << description() << "_" << node_id();
return ss.str();
}
typename T::ctype value() const { return m_value; }
// Make a constant from any value that can be converted to the C++ type we use
......
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "convert"; }
virtual std::string op_class_name() const override { return "convert"; }
virtual void propagate_types() override;
protected:
const ngraph::element::Type& m_element_type;
......
......@@ -25,7 +25,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "dot"; }
virtual std::string op_class_name() const override { return "dot"; }
virtual void propagate_types() override;
};
......
......@@ -38,9 +38,9 @@ namespace ngraph
public:
Parameter(const ValueType::ptr& value_type);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
virtual std::string node_id() const override;
protected:
Function* m_function;
......
......@@ -29,7 +29,7 @@ namespace ngraph
{
}
virtual std::string op_name() const override { return "tuple"; }
virtual std::string op_class_name() const override { return "tuple"; }
virtual void propagate_types() override;
};
}
......@@ -41,10 +41,8 @@ namespace ngraph
** Conversion to a vector of sizes.
**/
operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
protected:
std::vector<size_t> m_sizes;
};
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "topological_sort.hpp"
void ngraph::TopologicalSort::process(node_ptr node)
{
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
namespace ngraph
{
class TopologicalSort;
class Node;
using node_ptr = std::shared_ptr<Node>;
}
class ngraph::TopologicalSort
{
public:
TopologicalSort();
static void process(node_ptr);
private:
};
......@@ -89,7 +89,6 @@ namespace ngraph
** Construct empty tuple and add value types later.
**/
TupleType() {}
/**
** /param element_types A vector of types for the tuple elements
**/
......@@ -123,7 +122,6 @@ namespace ngraph
** /param type The new type
**/
void type(const ValueType::ptr& type) { m_type = type; }
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
......@@ -138,12 +136,10 @@ namespace ngraph
** The type associated with this value.
**/
ValueType::ptr type() { return m_type; }
/**
** The type associated with this value.
**/
const ValueType::ptr type() const { return m_type; }
protected:
ValueType::ptr m_type;
};
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <list>
#include <fstream>
#include <cstdio>
#include "visualize.hpp"
#include "ngraph/node.hpp"
#include "util.hpp"
using namespace ngraph;
using namespace std;
Visualize::Visualize(const string& name)
: m_name{name}
{
}
void Visualize::add(node_ptr p)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node)
{
for (auto arg : node->arguments())
{
m_ss << " " << arg->node_id() << " -> " << node->node_id() << ";\n";
}
});
}
void Visualize::save_dot(const string& path) const
{
auto tmp_file = path+".tmp";
ofstream out(tmp_file);
if (out)
{
out << "digraph " << m_name << "\n{\n";
out << m_ss.str();
out << "}\n";
out.close();
stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << path;
auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r");
pclose(stream);
// remove(tmp_file.c_str());
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <memory>
#include <set>
#include <sstream>
namespace ngraph
{
class Visualize;
class Node;
using node_ptr = std::shared_ptr<Node>;
}
class ngraph::Visualize
{
public:
Visualize(const std::string& name = "ngraph");
void add(node_ptr);
void save_dot(const std::string& path) const;
private:
std::stringstream m_ss;
std::string m_name;
};
......@@ -13,12 +13,20 @@
// ----------------------------------------------------------------------------
#include <algorithm>
#include <sstream>
#include "ngraph/ngraph.hpp"
using namespace ngraph;
using namespace std;
std::string ngraph::Op::node_id() const
{
stringstream ss;
ss << op_class_name() << "_" << m_instance_id;
return ss.str();
}
Node::ptr ngraph::op::abs(const Node::ptr& arg)
{
return make_shared<AbsOp>(arg);
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include "ngraph/ngraph.hpp"
using namespace std;
......@@ -34,7 +36,9 @@ void Parameter::assign_function(Function* function, size_t index)
m_index = index;
}
void Parameter::propagate_types() {}
void Parameter::propagate_types()
{
}
shared_ptr<Parameter> ngraph::op::parameter(const ValueType::ptr& value_type)
{
......@@ -46,3 +50,10 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_
{
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
std::string ngraph::Parameter::node_id() const
{
stringstream ss;
ss << "parameter_" << m_instance_id;
return ss.str();
}
......@@ -51,7 +51,7 @@ public:
bool is_list() const { return m_is_list; }
T get_value() const { return m_value; }
const std::vector<tree>& get_list() const { return m_list; }
static void traverse_tree(tree& s, std::function<void(T*)> func)
static void traverse_tree(tree& s, std::function<void(T*)> func)
{
if (s.is_list())
{
......
......@@ -16,6 +16,7 @@
#include <map>
#include "util.hpp"
#include "ngraph/node.hpp"
using namespace std;
......@@ -129,3 +130,25 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
}
return seed;
}
static void traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::function<void(std::shared_ptr<ngraph::Node>)> f,
std::set<size_t>& instances_seen)
{
f(p);
for (auto arg : p->arguments())
{
if (instances_seen.find(arg->instance_id()) == instances_seen.end())
{
instances_seen.insert(arg->instance_id());
traverse_nodes(arg, f, instances_seen);
}
}
}
void ngraph::traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::function<void(std::shared_ptr<ngraph::Node>)> f)
{
std::set<size_t> instances_seen;
::traverse_nodes(p, f, instances_seen);
}
......@@ -24,11 +24,12 @@
namespace ngraph
{
class Node;
class stopwatch;
extern std::map<std::string, stopwatch*> stopwatch_statistics;
template <typename T>
std::string join(const T& v, const std::string& sep)
std::string join(const T& v, const std::string& sep = ", ")
{
std::ostringstream ss;
for (const auto& x : v)
......@@ -83,10 +84,10 @@ namespace ngraph
}
size_t hash_combine(const std::vector<size_t>& list);
void dump(std::ostream& out, const void*, size_t);
void dump(std::ostream& out, const void*, size_t);
std::string to_lower(const std::string& s);
std::string trim(const std::string& s);
std::string to_lower(const std::string& s);
std::string trim(const std::string& s);
std::vector<std::string> split(const std::string& s, char delimiter, bool trim = false);
class stopwatch
......@@ -148,7 +149,6 @@ namespace ngraph
size_t get_total_milliseconds() const { return get_total_nanoseconds() / 1e6; }
size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; }
size_t get_total_nanoseconds() const { return m_total_time.count(); }
private:
std::chrono::high_resolution_clock m_clock;
std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time;
......@@ -194,4 +194,5 @@ namespace ngraph
return a * b;
}
void traverse_nodes(std::shared_ptr<Node> p, std::function<void(std::shared_ptr<Node>)> f);
} // end namespace ngraph
......@@ -73,7 +73,7 @@ public:
return memcmp((char*)m_data, (char*)other.m_data, 16) == 0;
}
bool operator!=(const uuid_type& other) const { return !(*this == other); }
bool operator!=(const uuid_type& other) const { return !(*this == other); }
friend std::ostream& operator<<(std::ostream& out, const uuid_type& id)
{
out << id.to_string();
......
......@@ -27,6 +27,8 @@ set (SRC
tensor.cpp
element_type.cpp
uuid.cpp
topological_sort.cpp
op.cpp
)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
......
......@@ -28,6 +28,7 @@ TEST(build_graph, build_simple)
auto arg3 = op::parameter(element::Float::type, {32, 7});
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto dot = op::dot(arg2, arg0);
ASSERT_EQ(2, dot->arguments().size());
ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0);
......@@ -96,4 +97,6 @@ TEST(build_graph, literal)
}
// Check argument inverses
TEST(build_graph, arg_inverse) {}
TEST(build_graph, arg_inverse)
{
}
......@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <memory>
#include <dlfcn.h>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph.hpp"
#include "log.hpp"
#include "ngraph.hpp"
using namespace std;
......
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
TEST(op, is_op)
{
auto arg0 = op::parameter(element::Float::type, {1});
ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op());
}
TEST(op, is_parameter)
{
auto arg0 = op::parameter(element::Float::type, {1});
ASSERT_NE(nullptr, arg0);
auto t0 = op::add(arg0, arg0);
ASSERT_NE(nullptr, t0);
EXPECT_FALSE(t0->is_parameter());
EXPECT_TRUE(t0->is_op());
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/topological_sort.hpp"
#include "ngraph/visualize.hpp"
using namespace std;
using namespace ngraph;
TEST(top_sort, basic)
{
auto arg0 = op::parameter(element::Float::type, {1});
ASSERT_NE(nullptr, arg0);
auto arg1 = op::parameter(element::Float::type, {1});
ASSERT_NE(nullptr, arg1);
auto t0 = op::add(arg0, arg1);
ASSERT_NE(nullptr, t0);
auto t1 = op::add(arg0, arg1);
ASSERT_NE(nullptr, t1);
Node::ptr r0 = op::add(t0, t1);
ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, {arg0, arg1});
ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0);
cout << "op_r0 name " << *r0 << endl;
Visualize vz;
vz.add(r0);
vz.save_dot("test.png");
TopologicalSort::process(r0);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment