Commit 5817c0a7 authored by Scott Cyphers's avatar Scott Cyphers

Switch to factory functions, style on some more files.

parent 494b16cd
...@@ -36,7 +36,8 @@ namespace nervana ...@@ -36,7 +36,8 @@ namespace nervana
return i < _size ? _string[i] : throw std::out_of_range(""); return i < _size ? _string[i] : throw std::out_of_range("");
} }
constexpr const char* get_ptr(size_t offset) const { return &_string[offset]; } constexpr const char* get_ptr(size_t offset) const { return &_string[offset]; }
constexpr size_t size() const { return _size; } constexpr size_t size() const { return _size; }
private: private:
const char* _string; const char* _string;
size_t _size; size_t _size;
...@@ -44,8 +45,9 @@ namespace nervana ...@@ -44,8 +45,9 @@ namespace nervana
constexpr const char* find_last(conststring s, size_t offset, char ch) constexpr const char* find_last(conststring s, size_t offset, char ch)
{ {
return offset == 0 ? s.get_ptr(0) : (s[offset] == ch ? s.get_ptr(offset + 1) return offset == 0
: find_last(s, offset - 1, ch)); ? s.get_ptr(0)
: (s[offset] == ch ? s.get_ptr(offset + 1) : find_last(s, offset - 1, ch));
} }
constexpr const char* find_last(conststring s, char ch) constexpr const char* find_last(conststring s, char ch)
...@@ -67,6 +69,7 @@ namespace nervana ...@@ -67,6 +69,7 @@ namespace nervana
~log_helper(); ~log_helper();
std::ostream& stream() { return _stream; } std::ostream& stream() { return _stream; }
private: private:
std::stringstream _stream; std::stringstream _stream;
}; };
...@@ -81,9 +84,9 @@ namespace nervana ...@@ -81,9 +84,9 @@ namespace nervana
static void stop(); static void stop();
private: private:
static void log_item(const std::string& s); static void log_item(const std::string& s);
static void process_event(const std::string& s); static void process_event(const std::string& s);
static void thread_entry(void* param); static void thread_entry(void* param);
static std::string log_path; static std::string log_path;
static std::deque<std::string> queue; static std::deque<std::string> queue;
}; };
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#pragma once #pragma once
#include <vector>
#include <set> #include <set>
#include <vector>
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
...@@ -31,7 +31,6 @@ namespace ngraph ...@@ -31,7 +31,6 @@ namespace ngraph
class Node : public TypedValueMixin class Node : public TypedValueMixin
{ {
public: public:
using ptr = std::shared_ptr<Node>; using ptr = std::shared_ptr<Node>;
Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = nullptr) Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = nullptr)
...@@ -39,7 +38,8 @@ namespace ngraph ...@@ -39,7 +38,8 @@ namespace ngraph
, m_arguments(arguments) , m_arguments(arguments)
{ {
// Add this node as a user of each argument. // Add this node as a user of each argument.
for(auto node : m_arguments){ for (auto node : m_arguments)
{
node->m_users.insert(node.get()); node->m_users.insert(node.get());
} }
} }
...@@ -52,6 +52,6 @@ namespace ngraph ...@@ -52,6 +52,6 @@ namespace ngraph
protected: protected:
std::vector<Node::ptr> m_arguments; std::vector<Node::ptr> m_arguments;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
}; };
} }
...@@ -50,8 +50,15 @@ namespace ngraph ...@@ -50,8 +50,15 @@ namespace ngraph
{ {
}; };
namespace op
{
std::shared_ptr<Node> broadcast(const Node::ptr& tensor, size_t axis);
}
class Broadcast : public Op, public std::enable_shared_from_this<Broadcast> class Broadcast : public Op, public std::enable_shared_from_this<Broadcast>
{ {
friend std::shared_ptr<Node> op::broadcast(const Node::ptr& tensor, size_t axis);
protected: protected:
class BroadcastCall : public Call class BroadcastCall : public Call
{ {
...@@ -68,29 +75,17 @@ namespace ngraph ...@@ -68,29 +75,17 @@ namespace ngraph
size_t m_axis; size_t m_axis;
}; };
public: static std::shared_ptr<Broadcast> s_op;
std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis)
{
return std::make_shared<BroadcastCall>(shared_from_this(), tensor, axis);
}
}; };
namespace op namespace op
{ {
extern decltype(*std::shared_ptr<Broadcast>()) broadcast; std::shared_ptr<Node> dot(const Node::ptr& arg0, const Node::ptr& arg1);
} }
class Dot : public Op, public std::enable_shared_from_this<Dot> class Dot : public Op, public std::enable_shared_from_this<Dot>
{ {
public: friend std::shared_ptr<Node> op::dot(const Node::ptr& arg0, const Node::ptr& arg1);
Call::ptr operator()(const Node::ptr& arg0, const Node::ptr& arg1) static std::shared_ptr<Dot> s_op;
{
return std::make_shared<Call>(shared_from_this(), std::vector<Node::ptr>{arg0, arg1});
}
}; };
namespace op
{
extern decltype(*std::shared_ptr<Dot>()) dot;
}
} }
...@@ -15,7 +15,18 @@ ...@@ -15,7 +15,18 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std;
decltype(*std::shared_ptr<Broadcast>()) ngraph::op::broadcast = *std::make_shared<Broadcast>(); shared_ptr<Broadcast> ngraph::Broadcast::s_op = make_shared<ngraph::Broadcast>();
decltype(*std::shared_ptr<Dot>()) ngraph::op::dot = *std::make_shared<Dot>(); shared_ptr<Node> ngraph::op::broadcast(const Node::ptr& tensor, size_t axis)
{
return make_shared<Broadcast::BroadcastCall>(Broadcast::s_op->shared_from_this(), tensor, axis);
}
shared_ptr<Dot> ngraph::Dot::s_op = make_shared<ngraph::Dot>();
shared_ptr<Node> ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<Call>(Dot::s_op->shared_from_this(), std::vector<Node::ptr>{arg0, arg1});
}
...@@ -51,7 +51,7 @@ public: ...@@ -51,7 +51,7 @@ public:
bool is_list() const { return m_is_list; } bool is_list() const { return m_is_list; }
T get_value() const { return m_value; } T get_value() const { return m_value; }
const std::vector<tree>& get_list() const { return m_list; } const std::vector<tree>& get_list() const { return m_list; }
static void traverse_tree(tree& s, std::function<void(T*)> func) static void traverse_tree(tree& s, std::function<void(T*)> func)
{ {
if (s.is_list()) if (s.is_list())
{ {
......
...@@ -83,10 +83,10 @@ namespace ngraph ...@@ -83,10 +83,10 @@ namespace ngraph
} }
size_t hash_combine(const std::vector<size_t>& list); size_t hash_combine(const std::vector<size_t>& list);
void dump(std::ostream& out, const void*, size_t); void dump(std::ostream& out, const void*, size_t);
std::string to_lower(const std::string& s); std::string to_lower(const std::string& s);
std::string trim(const std::string& s); std::string trim(const std::string& s);
std::vector<std::string> split(const std::string& s, char delimiter, bool trim = false); std::vector<std::string> split(const std::string& s, char delimiter, bool trim = false);
class stopwatch class stopwatch
...@@ -148,6 +148,7 @@ namespace ngraph ...@@ -148,6 +148,7 @@ namespace ngraph
size_t get_total_milliseconds() const { return get_total_nanoseconds() / 1e6; } size_t get_total_milliseconds() const { return get_total_nanoseconds() / 1e6; }
size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; } size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; }
size_t get_total_nanoseconds() const { return m_total_time.count(); } size_t get_total_nanoseconds() const { return m_total_time.count(); }
private: private:
std::chrono::high_resolution_clock m_clock; std::chrono::high_resolution_clock m_clock;
std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time; std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time;
......
...@@ -73,7 +73,7 @@ public: ...@@ -73,7 +73,7 @@ public:
return memcmp((char*)m_data, (char*)other.m_data, 16) == 0; return memcmp((char*)m_data, (char*)other.m_data, 16) == 0;
} }
bool operator!=(const uuid_type& other) const { return !(*this == other); } bool operator!=(const uuid_type& other) const { return !(*this == other); }
friend std::ostream& operator<<(std::ostream& out, const uuid_type& id) friend std::ostream& operator<<(std::ostream& out, const uuid_type& id)
{ {
out << id.to_string(); out << id.to_string();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment