Commit 064fb0fc authored by Scott Cyphers's avatar Scott Cyphers

formatting

parent fc0455ba
......@@ -15,8 +15,8 @@
#pragma once
#include <memory>
#include <vector>
#include <set>
#include <vector>
// Names for types that aren't worth giving their own classes
namespace ngraph
......@@ -25,7 +25,7 @@ namespace ngraph
class Parameter;
class ValueType;
template<typename T, typename ...A>
template <typename T, typename... A>
std::shared_ptr<T> node(A&&... args)
{
return std::make_shared<T>(args...);
......
......@@ -43,6 +43,7 @@ namespace ngraph
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
private:
static std::map<std::string, Type> m_element_list;
size_t m_bitwidth;
......@@ -54,15 +55,19 @@ namespace ngraph
// Provides a compile-time name for a C++ type.
// Used in TraitedType for the string that supplies the C++ type name during code generation,
// so it needs to be a valid C++ name.
template<typename T>
template <typename T>
const char* traited_type_name()
{
throw ngraph_error("Unknown type");
}
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TTN( T ) \
template<> constexpr const char* traited_type_name < T > () { return #T; }
// Define a type string for a type T. Will make traited_type_name<T>() return "T"
#define NGRAPH_DEFINE_TTN(T) \
template <> \
constexpr const char* traited_type_name<T>() \
{ \
return #T; \
}
// Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides
......@@ -83,31 +88,32 @@ namespace ngraph
// This is the C++ type used to hold a value of this element type during compilation
using type = T;
// This returns a reference to an instance of this element type.
static const TraitedType<T>& element_type(){
static const TraitedType<T>& element_type()
{
static TraitedType<T> t;
return t;
}
};
NGRAPH_DEFINE_TTN( float )
NGRAPH_DEFINE_TTN(float)
using Float = TraitedType<float>;
NGRAPH_DEFINE_TTN( int8_t )
NGRAPH_DEFINE_TTN(int8_t)
using Int8 = TraitedType<int8_t>;
NGRAPH_DEFINE_TTN( int32_t )
NGRAPH_DEFINE_TTN(int32_t)
using Int32 = TraitedType<int32_t>;
NGRAPH_DEFINE_TTN( int64_t )
NGRAPH_DEFINE_TTN(int64_t)
using Int64 = TraitedType<int64_t>;
NGRAPH_DEFINE_TTN( uint8_t )
NGRAPH_DEFINE_TTN(uint8_t)
using UInt8 = TraitedType<uint8_t>;
NGRAPH_DEFINE_TTN( uint32_t )
NGRAPH_DEFINE_TTN(uint32_t)
using UInt32 = TraitedType<uint32_t>;
NGRAPH_DEFINE_TTN( uint64_t )
NGRAPH_DEFINE_TTN(uint64_t)
using UInt64 = TraitedType<uint64_t>;
}
}
......@@ -31,6 +31,7 @@ namespace ngraph
std::shared_ptr<Node> result() { return m_result; }
std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; }
std::string name() const { return m_name; }
protected:
std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
......
......@@ -17,7 +17,8 @@
size_t ngraph::Node::m_next_instance_id = 0;
ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, std::shared_ptr<ValueType> type)
ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
std::shared_ptr<ValueType> type)
: TypedValueMixin(type)
, m_arguments(arguments)
, m_instance_id(m_next_instance_id++)
......
......@@ -20,8 +20,8 @@
#include <iostream>
#include "type.hpp"
#include "common.hpp"
#include "type.hpp"
namespace ngraph
{
......@@ -32,11 +32,11 @@ namespace ngraph
/// view or a (possibly empty) tuple of values.
class Node : public TypedValueMixin, public std::enable_shared_from_this<Node>
{
protected:
Node(const Nodes& arguments, std::shared_ptr<ValueType> type = nullptr);
virtual ~Node() {}
public:
/// A "one-liner" describing this node.
virtual std::string description() const = 0;
......
......@@ -24,39 +24,51 @@ namespace ngraph
{
namespace op
{
std::shared_ptr<Node> abs(const std::shared_ptr<Node>& arg);
std::shared_ptr<Node> add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> convert();
//std::shared_ptr<Node> convolution();
std::shared_ptr<Node> divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> equal(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> exp(const std::shared_ptr<Node>& arg0);
std::shared_ptr<Node> floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> get_tuple_element();
std::shared_ptr<Node> greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> greater(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> less(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> log(const std::shared_ptr<Node>& arg0);
//std::shared_ptr<Node> logical(); and, or, not
std::shared_ptr<Node> maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> negative(const std::shared_ptr<Node>& arg0);
//std::shared_ptr<Node> pad();
std::shared_ptr<Node> power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> reduce();
// std::shared_ptr<Node> reduce_window();
std::shared_ptr<Node> remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& arg0, const Shape& shape);
//std::shared_ptr<Node> reverse();
//std::shared_ptr<Node> rng();
//std::shared_ptr<Node> select();
//std::shared_ptr<Node> select_scatter();
//std::shared_ptr<Node> slice();
std::shared_ptr<Node> subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> transpose();
//std::shared_ptr<Node> while();
}
......@@ -81,6 +93,7 @@ namespace ngraph
class FunctionOp : public Op
{
virtual std::string description() const override { return "FunctionOp"; }
protected:
std::shared_ptr<Node> m_function;
};
......@@ -95,6 +108,7 @@ namespace ngraph
// TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {}
protected:
BuiltinOp(const std::vector<std::shared_ptr<Node>>& args)
: Op(args)
......
......@@ -25,7 +25,9 @@ namespace ngraph
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
///
BroadcastOp(const std::shared_ptr<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes)
BroadcastOp(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
......
......@@ -16,7 +16,6 @@
namespace ngraph
{
class ConvertOp : public BuiltinOp
{
public:
......@@ -28,13 +27,14 @@ namespace ngraph
virtual std::string get_op_class_name() const override { return "convert"; }
virtual void propagate_types() override;
protected:
const ngraph::element::Type& m_element_type;
};
namespace op
{
std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg,
const ngraph::element::Type& element_type);
}
}
......@@ -31,6 +31,7 @@ namespace ngraph
namespace op
{
std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
}
}
......@@ -51,7 +51,8 @@ namespace ngraph
namespace op
{
/// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
std::shared_ptr<ngraph::Parameter>
parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
/// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape);
......
......@@ -39,6 +39,7 @@ namespace ngraph
operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
protected:
std::vector<size_t> m_sizes;
};
......
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "node.hpp"
#include "topological_sort.hpp"
#include "node.hpp"
#include "util.hpp"
using namespace ngraph;
......@@ -21,7 +21,7 @@ using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n)
{
for (auto dn=m_dependent_nodes.begin(); dn!=m_dependent_nodes.end(); dn++)
for (auto dn = m_dependent_nodes.begin(); dn != m_dependent_nodes.end(); dn++)
{
if (dn->first > 0) // Skip zero as they should never be promoted
{
......@@ -30,7 +30,7 @@ void ngraph::TopologicalSort::promote_node(Node* n)
{
// found the node
dn->second.erase(it);
m_dependent_nodes[dn->first-1].push_back(n);
m_dependent_nodes[dn->first - 1].push_back(n);
}
}
}
......@@ -38,8 +38,7 @@ void ngraph::TopologicalSort::promote_node(Node* n)
void ngraph::TopologicalSort::process(node_ptr p)
{
traverse_nodes(p, [&](node_ptr node)
{
traverse_nodes(p, [&](node_ptr node) {
list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()];
node_list.push_back(node.get());
});
......
......@@ -14,9 +14,10 @@
#pragma once
#include <memory>
#include <map>
#include <list>
#include <map>
#include <memory>
#include <vector>
namespace ngraph
{
......
......@@ -71,7 +71,10 @@ namespace ngraph
{
}
const std::vector<std::shared_ptr<ValueType>> get_element_types() const { return m_element_types; }
const std::vector<std::shared_ptr<ValueType>> get_element_types() const
{
return m_element_types;
}
std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; }
virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
......@@ -95,7 +98,10 @@ namespace ngraph
** Set the type
** /param type The new type
**/
void set_value_type(const std::shared_ptr<ValueType>& value_type) { m_value_type = value_type; }
void set_value_type(const std::shared_ptr<ValueType>& value_type)
{
m_value_type = value_type;
}
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
......@@ -114,6 +120,7 @@ namespace ngraph
** The type associated with this value.
**/
const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; }
protected:
std::shared_ptr<ValueType> m_value_type;
};
......
......@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <list>
#include <fstream>
#include <cstdio>
#include <fstream>
#include <list>
#include "visualize.hpp"
#include "ngraph/node.hpp"
#include "util.hpp"
#include "visualize.hpp"
using namespace ngraph;
using namespace std;
......@@ -31,8 +31,7 @@ Visualize::Visualize(const string& name)
void Visualize::add(node_ptr p)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node)
{
traverse_nodes(p, [&](node_ptr node) {
for (auto arg : node->get_arguments())
{
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n";
......@@ -42,7 +41,7 @@ void Visualize::add(node_ptr p)
void Visualize::save_dot(const string& path) const
{
auto tmp_file = path+".tmp";
auto tmp_file = path + ".tmp";
ofstream out(tmp_file);
if (out)
{
......
......@@ -24,7 +24,8 @@ void ConvertOp::propagate_types()
throw ngraph_error("NIY");
}
shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg, const element::Type& element_type)
shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg,
const element::Type& element_type)
{
return make_shared<ConvertOp>(arg, element_type);
}
......@@ -20,15 +20,18 @@ using namespace std;
using namespace ngraph;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<DotOp>(arg0, arg1);
}
void DotOp::propagate_types()
{
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
auto arg0_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
......
......@@ -32,12 +32,14 @@ std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
return make_shared<AbsOp>(arg);
}
std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<AddOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<CeilingOp>(arg0, arg1);
}
......@@ -45,7 +47,8 @@ std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, con
// 'convert',
// 'convolution',
std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<DivideOp>(arg0, arg1);
}
......@@ -55,7 +58,8 @@ std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
return make_shared<ExpOp>(arg0);
}
std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<FloorOp>(arg0, arg1);
}
......@@ -65,17 +69,20 @@ std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
return make_shared<LogOp>(arg0);
}
std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MaximumOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MinimumOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MultiplyOp>(arg0, arg1);
}
......@@ -87,14 +94,16 @@ std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
// 'pad',
std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<PowerOp>(arg0, arg1);
}
//'reduce',
std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<RemainderOp>(arg0, arg1);
}
......@@ -109,7 +118,8 @@ std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, con
// 'select',
//'slice',
std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<SubtractOp>(arg0, arg1);
}
......
......@@ -41,9 +41,7 @@ void Parameter::assign_function(Function* function, size_t index)
m_index = index;
}
void Parameter::propagate_types()
{
}
void Parameter::propagate_types() {}
shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type)
{
......
......@@ -99,6 +99,4 @@ TEST(build_graph, literal)
}
// Check argument inverses
TEST(build_graph, arg_inverse)
{
}
TEST(build_graph, arg_inverse) {}
......@@ -29,7 +29,7 @@ using namespace ngraph;
static bool validate_list(const vector<Node*>& nodes)
{
bool rc = true;
for (auto it=nodes.rbegin(); it!=nodes.rend(); it++)
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{
auto node_tmp = *it;
auto dependencies_tmp = node_tmp->get_arguments();
......@@ -38,8 +38,8 @@ static bool validate_list(const vector<Node*>& nodes)
{
dependencies.push_back(n.get());
}
auto tmp = it+1;
for (; tmp!=nodes.rend(); tmp++)
auto tmp = it + 1;
for (; tmp != nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
......@@ -59,7 +59,7 @@ static bool validate_list(const vector<Node*>& nodes)
TEST(topological_sort, basic)
{
vector<shared_ptr<Parameter>> args;
for (int i=0; i<10; i++)
for (int i = 0; i < 10; i++)
{
auto arg = op::parameter(element::Float::element_type(), {1});
ASSERT_NE(nullptr, arg);
......
......@@ -134,9 +134,7 @@ TEST(util, contains)
EXPECT_FALSE(contains(v1, 8));
}
TEST(util, remove_from)
{
}
TEST(util, remove_from) {}
TEST(util, reduce)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment