Commit f256e75d authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #21 from NervanaSystems/bob/namespace

put everything ngraph in the ngraph namespace
parents bdf38828 94f7f22c
...@@ -17,17 +17,17 @@ ...@@ -17,17 +17,17 @@
#include "element_type.hpp" #include "element_type.hpp"
const ElementType element_type_float = ElementType(32, true, true, "float"); const ngraph::ElementType element_type_float = ngraph::ElementType(32, true, true, "float");
const ElementType element_type_int8_t = ElementType(8, false, true, "int8_t"); const ngraph::ElementType element_type_int8_t = ngraph::ElementType(8, false, true, "int8_t");
const ElementType element_type_int32_t = ElementType(32, false, true, "int32_t"); const ngraph::ElementType element_type_int32_t = ngraph::ElementType(32, false, true, "int32_t");
const ElementType element_type_int64_t = ElementType(64, false, true, "int64_t"); const ngraph::ElementType element_type_int64_t = ngraph::ElementType(64, false, true, "int64_t");
const ElementType element_type_uint8_t = ElementType(8, false, false, "int8_t"); const ngraph::ElementType element_type_uint8_t = ngraph::ElementType(8, false, false, "int8_t");
const ElementType element_type_uint32_t = ElementType(32, false, false, "int32_t"); const ngraph::ElementType element_type_uint32_t = ngraph::ElementType(32, false, false, "int32_t");
const ElementType element_type_uint64_t = ElementType(64, false, false, "int64_t"); const ngraph::ElementType element_type_uint64_t = ngraph::ElementType(64, false, false, "int64_t");
std::map<std::string, ElementType> ElementType::m_element_list; std::map<std::string, ngraph::ElementType> ngraph::ElementType::m_element_list;
ElementType::ElementType(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname) ngraph::ElementType::ElementType(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname)
: m_bitwidth{bitwidth} : m_bitwidth{bitwidth}
, m_is_float{is_float} , m_is_float{is_float}
, m_is_signed{is_signed} , m_is_signed{is_signed}
...@@ -36,18 +36,18 @@ ElementType::ElementType(size_t bitwidth, bool is_float, bool is_signed, const s ...@@ -36,18 +36,18 @@ ElementType::ElementType(size_t bitwidth, bool is_float, bool is_signed, const s
assert(m_bitwidth % 8 == 0); assert(m_bitwidth % 8 == 0);
} }
const std::string& ElementType::c_type_string() const const std::string& ngraph::ElementType::c_type_string() const
{ {
return m_cname; return m_cname;
} }
bool ElementType::operator==(const ElementType& other) const bool ngraph::ElementType::operator==(const ElementType& other) const
{ {
return m_bitwidth == other.m_bitwidth && m_is_float == other.m_is_float && return m_bitwidth == other.m_bitwidth && m_is_float == other.m_is_float &&
m_is_signed == other.m_is_signed; m_is_signed == other.m_is_signed;
} }
size_t ElementType::size() const size_t ngraph::ElementType::size() const
{ {
return std::ceil((float)m_bitwidth / 8.0); return std::ceil((float)m_bitwidth / 8.0);
} }
...@@ -21,7 +21,12 @@ ...@@ -21,7 +21,12 @@
#include <string> #include <string>
#include <map> #include <map>
class ElementType namespace ngraph
{
class ElementType;
}
class ngraph::ElementType
{ {
public: public:
ElementType(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname); ElementType(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname);
...@@ -44,10 +49,10 @@ private: ...@@ -44,10 +49,10 @@ private:
const std::string m_cname; const std::string m_cname;
}; };
extern const ElementType element_type_float; extern const ngraph::ElementType element_type_float;
extern const ElementType element_type_int8_t; extern const ngraph::ElementType element_type_int8_t;
extern const ElementType element_type_int32_t; extern const ngraph::ElementType element_type_int32_t;
extern const ElementType element_type_int64_t; extern const ngraph::ElementType element_type_int64_t;
extern const ElementType element_type_uint8_t; extern const ngraph::ElementType element_type_uint8_t;
extern const ElementType element_type_uint32_t; extern const ngraph::ElementType element_type_uint32_t;
extern const ElementType element_type_uint64_t; extern const ngraph::ElementType element_type_uint64_t;
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include "names.hpp" #include "names.hpp"
using namespace ngraph;
size_t NameableValue::__counter = 0; size_t NameableValue::__counter = 0;
std::map<std::string, NameableValue> NameableValue::__all_names; std::map<std::string, NameableValue> NameableValue::__all_names;
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <string> #include <string>
#include <map> #include <map>
namespace ngraph
{
//================================================================================================ //================================================================================================
// NameableValue // NameableValue
// An Axis labels a dimension of a tensor. The op-graph uses // An Axis labels a dimension of a tensor. The op-graph uses
...@@ -101,3 +104,6 @@ public: ...@@ -101,3 +104,6 @@ public:
std::string m_short_name; std::string m_short_name;
std::string m_doc_string; std::string m_doc_string;
}; };
} // end namespace ngraph
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include "axes.hpp" #include "axes.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph;
slice::slice(int64_t start, int64_t stop, int64_t step) slice::slice(int64_t start, int64_t stop, int64_t step)
: m_start{(size_t)start} : m_start{(size_t)start}
, m_stop{(size_t)stop} , m_stop{(size_t)stop}
...@@ -90,12 +92,12 @@ size_t slice::sliced_length(size_t length) const ...@@ -90,12 +92,12 @@ size_t slice::sliced_length(size_t length) const
// raise TypeError("Could not cast {} to np.dtype".format(dtype)) // raise TypeError("Could not cast {} to np.dtype".format(dtype))
// return dtype // return dtype
Axis make_axis(size_t length, const std::string& name, bool batch, bool recurrent) Axis ngraph::make_axis(size_t length, const std::string& name, bool batch, bool recurrent)
{ {
return Axis(length, name); return Axis(length, name);
} }
Axes make_axes(const std::vector<Axis>& axis_list) Axes ngraph::make_axes(const std::vector<Axis>& axis_list)
{ {
return Axes(axis_list); return Axes(axis_list);
} }
...@@ -175,7 +177,7 @@ void Axis::length(size_t l) ...@@ -175,7 +177,7 @@ void Axis::length(size_t l)
__length = l; __length = l;
} }
std::ostream& operator<<(std::ostream& out, const Axis& axis) std::ostream& ngraph::operator<<(std::ostream& out, const Axis& axis)
{ {
out << axis.to_string(); out << axis.to_string();
return out; return out;
...@@ -238,7 +240,7 @@ bool Axis::operator<(const Axis& other) const ...@@ -238,7 +240,7 @@ bool Axis::operator<(const Axis& other) const
// // )) // // ))
// } // }
Axis slice_axis(const Axis& axis, const slice& s) Axis ngraph::slice_axis(const Axis& axis, const slice& s)
{ {
// _validate_slice(s) // _validate_slice(s)
...@@ -263,7 +265,7 @@ Axis slice_axis(const Axis& axis, const slice& s) ...@@ -263,7 +265,7 @@ Axis slice_axis(const Axis& axis, const slice& s)
// Returns: // Returns:
// list of Axis: duplicate Axis found in arr // list of Axis: duplicate Axis found in arr
// """ // """
std::vector<std::string> duplicates(const std::vector<Axis>& ax) std::vector<std::string> ngraph::duplicates(const std::vector<Axis>& ax)
{ {
std::map<std::string, size_t> counts; std::map<std::string, size_t> counts;
std::vector<std::string> rc; std::vector<std::string> rc;
...@@ -835,7 +837,7 @@ bool Axes::operator<(const Axes& other) const ...@@ -835,7 +837,7 @@ bool Axes::operator<(const Axes& other) const
// """ // """
// return int(np.prod(self.lengths)) // return int(np.prod(self.lengths))
std::ostream& operator<<(std::ostream& out, const Axes& axes) std::ostream& ngraph::operator<<(std::ostream& out, const Axes& axes)
{ {
out << "Axes("; out << "Axes(";
out << join(axes.axes, ", "); out << join(axes.axes, ", ");
...@@ -1060,7 +1062,7 @@ FlattenedAxis::FlattenedAxis(const std::vector<Axis>& list, const std::string& n ...@@ -1060,7 +1062,7 @@ FlattenedAxis::FlattenedAxis(const std::vector<Axis>& list, const std::string& n
axes = list; axes = list;
} }
std::ostream& operator<<(std::ostream& out, const FlattenedAxis& obj) std::ostream& ngraph::operator<<(std::ostream& out, const FlattenedAxis& obj)
{ {
out << obj.to_string(); out << obj.to_string();
return out; return out;
......
...@@ -28,6 +28,8 @@ ...@@ -28,6 +28,8 @@
#include "strides.hpp" #include "strides.hpp"
#include "uuid.hpp" #include "uuid.hpp"
namespace ngraph
{
class Axes; class Axes;
class Axis; class Axis;
class FlattenedAxis; class FlattenedAxis;
...@@ -232,20 +234,6 @@ public: ...@@ -232,20 +234,6 @@ public:
static size_t __name_counter; static size_t __name_counter;
}; };
namespace std
{
template <>
struct std::hash<Axis>
{
size_t operator()(const Axis& axis) const
{
std::hash<std::string> h1;
std::hash<size_t> h2;
return hash_combine({h1(axis.name), h2(axis.length())});
}
};
}
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
// _sliced_length // _sliced_length
//----------------------------------------------------------------------------------------------- //-----------------------------------------------------------------------------------------------
...@@ -722,24 +710,6 @@ private: ...@@ -722,24 +710,6 @@ private:
void check_duplicates(); void check_duplicates();
}; };
namespace std
{
template <>
struct std::hash<Axes>
{
size_t operator()(const Axes& axes) const
{
std::hash<Axis> h1;
std::vector<size_t> hashes;
for (auto axis : axes)
{
hashes.push_back(h1(axis));
}
return hash_combine(hashes);
}
};
}
//================================================================================================ //================================================================================================
// DuplicateAxisNames // DuplicateAxisNames
//================================================================================================ //================================================================================================
...@@ -1518,3 +1488,37 @@ public: ...@@ -1518,3 +1488,37 @@ public:
ngraph::tensor_stride full_strides; ngraph::tensor_stride full_strides;
tensor_description_ptr next_tensor_description; tensor_description_ptr next_tensor_description;
}; };
} // end of namespace ngraph
namespace std
{
template <>
struct std::hash<ngraph::Axis>
{
size_t operator()(const ngraph::Axis& axis) const
{
std::hash<std::string> h1;
std::hash<size_t> h2;
return ngraph::hash_combine({h1(axis.name), h2(axis.length())});
}
};
}
namespace std
{
template <>
struct std::hash<ngraph::Axes>
{
size_t operator()(const ngraph::Axes& axes) const
{
std::hash<ngraph::Axis> h1;
std::vector<size_t> hashes;
for (auto axis : axes)
{
hashes.push_back(h1(axis));
}
return ngraph::hash_combine(hashes);
}
};
}
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "op_graph.hpp" #include "op_graph.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph;
//================================================================================================ //================================================================================================
// InputDecl // InputDecl
//================================================================================================ //================================================================================================
...@@ -77,7 +79,7 @@ void InputDecl::value(OutputDecl* value) ...@@ -77,7 +79,7 @@ void InputDecl::value(OutputDecl* value)
} }
} }
std::ostream& operator<<(std::ostream& out, const InputDecl& obj) std::ostream& ngraph::operator<<(std::ostream& out, const InputDecl& obj)
{ {
out << "Arg(" << obj.exop.name() << obj.pos << ")"; out << "Arg(" << obj.exop.name() << obj.pos << ")";
return out; return out;
...@@ -142,7 +144,7 @@ void OutputDecl::write_view(tensor_view_decl_ptr view) ...@@ -142,7 +144,7 @@ void OutputDecl::write_view(tensor_view_decl_ptr view)
} }
} }
std::ostream& operator<<(std::ostream& out, const OutputDecl& obj) std::ostream& ngraph::operator<<(std::ostream& out, const OutputDecl& obj)
{ {
out << "Val(" << obj.exop.name() << ":" << obj.pos << ")"; out << "Val(" << obj.exop.name() << ":" << obj.pos << ")";
return out; return out;
...@@ -191,7 +193,7 @@ ExOp::ExOp(ComputationDecl& cgraph, op_ptr _op, bool create_value) ...@@ -191,7 +193,7 @@ ExOp::ExOp(ComputationDecl& cgraph, op_ptr _op, bool create_value)
} }
} }
std::ostream& operator<<(std::ostream& out, const ExOp& obj) std::ostream& ngraph::operator<<(std::ostream& out, const ExOp& obj)
{ {
out << obj.op->name(); out << obj.op->name();
std::vector<std::string> args; std::vector<std::string> args;
...@@ -833,7 +835,7 @@ std::string TensorDecl::buffer_name() ...@@ -833,7 +835,7 @@ std::string TensorDecl::buffer_name()
// return op->name(); // return op->name();
// } // }
std::ostream& operator<<(std::ostream& out, const TensorDecl& obj) std::ostream& ngraph::operator<<(std::ostream& out, const TensorDecl& obj)
{ {
out << obj.tensor_description_base->name(); out << obj.tensor_description_base->name();
return out; return out;
......
...@@ -27,6 +27,9 @@ ...@@ -27,6 +27,9 @@
#include "op_graph.hpp" #include "op_graph.hpp"
#include "axes.hpp" #include "axes.hpp"
namespace ngraph
{
// forward declaration. This will hopefully go away // forward declaration. This will hopefully go away
class ExecutionGraph; class ExecutionGraph;
class TensorDescription; class TensorDescription;
...@@ -450,3 +453,5 @@ public: ...@@ -450,3 +453,5 @@ public:
std::map<tensor_description_ptr, tensor_decl_ptr> tensor_decls; std::map<tensor_description_ptr, tensor_decl_ptr> tensor_decls;
computation_decl_ptr computation_decl; computation_decl_ptr computation_decl;
}; };
} // end namespace ngraph
...@@ -23,6 +23,9 @@ ...@@ -23,6 +23,9 @@
#include "element_type.hpp" #include "element_type.hpp"
namespace ngraph
{
class ExecutionState; class ExecutionState;
class Op; class Op;
...@@ -175,3 +178,5 @@ public: ...@@ -175,3 +178,5 @@ public:
// private: // private:
// std::vector<op_ptr> m_all_deps; // std::vector<op_ptr> m_all_deps;
// }; // };
} // end of namespace ngraph
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include "mock.hpp" #include "mock.hpp"
#include "exop.hpp" #include "exop.hpp"
namespace ngraph
{
//================================================================================================ //================================================================================================
// CpuTransformer // CpuTransformer
//================================================================================================ //================================================================================================
...@@ -30,3 +33,5 @@ public: ...@@ -30,3 +33,5 @@ public:
private: private:
ExecutionState m_execution_state; ExecutionState m_execution_state;
}; };
} // end namespace ngraph
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include "axes.hpp" #include "axes.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph;
// def tensor_descriptions(args): // def tensor_descriptions(args):
// """ // """
// A list of tensor descriptions for Ops. // A list of tensor descriptions for Ops.
...@@ -1913,7 +1915,7 @@ BroadcastOp::BroadcastOp(op_ptr x, Axes axes) ...@@ -1913,7 +1915,7 @@ BroadcastOp::BroadcastOp(op_ptr x, Axes axes)
// dx_reordered = axes_with_order(dx, x.axes) // dx_reordered = axes_with_order(dx, x.axes)
// x.generate_add_delta(adjoints, dx_reordered) // x.generate_add_delta(adjoints, dx_reordered)
op_ptr broadcast(op_ptr x, const Axes& axes) op_ptr ngraph::broadcast(op_ptr x, const Axes& axes)
{ {
// auto axes = make_axes(axis_list); // auto axes = make_axes(axis_list);
op_ptr rc; op_ptr rc;
...@@ -1928,7 +1930,7 @@ op_ptr broadcast(op_ptr x, const Axes& axes) ...@@ -1928,7 +1930,7 @@ op_ptr broadcast(op_ptr x, const Axes& axes)
return rc; return rc;
} }
op_ptr axes_with_order(op_ptr x, const std::vector<Axis>& axis_list) op_ptr ngraph::axes_with_order(op_ptr x, const std::vector<Axis>& axis_list)
{ {
auto axes = make_axes(axis_list); auto axes = make_axes(axis_list);
op_ptr rc; op_ptr rc;
......
...@@ -23,6 +23,9 @@ ...@@ -23,6 +23,9 @@
#include "axes.hpp" #include "axes.hpp"
#include "names.hpp" #include "names.hpp"
namespace ngraph
{
class Op; class Op;
class AssignableTensorOp; class AssignableTensorOp;
class ParallelOp; class ParallelOp;
...@@ -4427,3 +4430,5 @@ public: ...@@ -4427,3 +4430,5 @@ public:
// private: // private:
// std::vector<op_ptr> m_all_deps; // std::vector<op_ptr> m_all_deps;
}; };
} // end namespace ngraph
...@@ -19,9 +19,9 @@ ...@@ -19,9 +19,9 @@
using namespace std; using namespace std;
map<string, stopwatch*> stopwatch_statistics; map<string, ngraph::stopwatch*> ngraph::stopwatch_statistics;
void dump(ostream& out, const void* _data, size_t _size) void ngraph::dump(ostream& out, const void* _data, size_t _size)
{ {
auto flags = out.flags(); auto flags = out.flags();
const uint8_t* data = reinterpret_cast<const uint8_t*>(_data); const uint8_t* data = reinterpret_cast<const uint8_t*>(_data);
...@@ -66,14 +66,14 @@ void dump(ostream& out, const void* _data, size_t _size) ...@@ -66,14 +66,14 @@ void dump(ostream& out, const void* _data, size_t _size)
out.flags(flags); out.flags(flags);
} }
std::string to_lower(const std::string& s) std::string ngraph::to_lower(const std::string& s)
{ {
std::string rc = s; std::string rc = s;
std::transform(rc.begin(), rc.end(), rc.begin(), ::tolower); std::transform(rc.begin(), rc.end(), rc.begin(), ::tolower);
return rc; return rc;
} }
string trim(const string& s) string ngraph::trim(const string& s)
{ {
string rc = s; string rc = s;
// trim trailing spaces // trim trailing spaces
...@@ -92,7 +92,7 @@ string trim(const string& s) ...@@ -92,7 +92,7 @@ string trim(const string& s)
return rc; return rc;
} }
vector<string> split(const string& src, char delimiter, bool do_trim) vector<string> ngraph::split(const string& src, char delimiter, bool do_trim)
{ {
size_t pos; size_t pos;
string token; string token;
...@@ -120,7 +120,7 @@ vector<string> split(const string& src, char delimiter, bool do_trim) ...@@ -120,7 +120,7 @@ vector<string> split(const string& src, char delimiter, bool do_trim)
return rc; return rc;
} }
size_t hash_combine(const std::vector<size_t>& list) size_t ngraph::hash_combine(const std::vector<size_t>& list)
{ {
size_t seed = 0; size_t seed = 0;
for (size_t v : list) for (size_t v : list)
......
...@@ -22,6 +22,9 @@ ...@@ -22,6 +22,9 @@
#include <map> #include <map>
#include <iostream> #include <iostream>
namespace ngraph
{
class stopwatch; class stopwatch;
extern std::map<std::string, stopwatch*> stopwatch_statistics; extern std::map<std::string, stopwatch*> stopwatch_statistics;
...@@ -157,39 +160,38 @@ private: ...@@ -157,39 +160,38 @@ private:
std::string m_name; std::string m_name;
}; };
namespace ngraph template <class InputIt, class BinaryOp>
typename std::iterator_traits<InputIt>::value_type
reduce(InputIt first, InputIt last, BinaryOp op)
{ {
template <class InputIt, class BinaryOp> typename std::iterator_traits<InputIt>::value_type result;
typename std::iterator_traits<InputIt>::value_type
reduce(InputIt first, InputIt last, BinaryOp op)
{
typename std::iterator_traits<InputIt>::value_type result;
if (first == last) if (first == last)
{ {
result = {}; result = {};
} }
else else
{
result = *first++;
while (first != last)
{ {
result = *first++; result = op(result, *first);
while (first != last) first++;
{
result = op(result, *first);
first++;
}
} }
return result;
} }
return result;
}
template <typename T> template <typename T>
T plus(const T& a, const T& b) T plus(const T& a, const T& b)
{ {
return a + b; return a + b;
} }
template <typename T> template <typename T>
T mul(const T& a, const T& b) T mul(const T& a, const T& b)
{ {
return a * b; return a * b;
}
} }
} // end namespace ngraph
...@@ -22,7 +22,12 @@ ...@@ -22,7 +22,12 @@
static std::mt19937_64 random_generator; static std::mt19937_64 random_generator;
class uuid_type namespace ngraph
{
class uuid_type;
}
class ngraph::uuid_type
{ {
public: public:
uuid_type() uuid_type()
......
...@@ -27,7 +27,7 @@ set (SRC ...@@ -27,7 +27,7 @@ set (SRC
exop.cpp exop.cpp
axes.cpp axes.cpp
element_type.cpp element_type.cpp
op.cpp op_graph.cpp
uuid.cpp uuid.cpp
names.cpp names.cpp
strides.cpp strides.cpp
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "transformers/ndarray.hpp" #include "transformers/ndarray.hpp"
using namespace std; using namespace std;
using namespace ngraph;
// axes for testing // axes for testing
static auto ax_A = make_axis(2, "A"); static auto ax_A = make_axis(2, "A");
......
...@@ -19,3 +19,5 @@ ...@@ -19,3 +19,5 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "element_type.hpp" #include "element_type.hpp"
using namespace ngraph;
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include "transformers/mock.hpp" #include "transformers/mock.hpp"
#include "transformers/mock_transformer.hpp" #include "transformers/mock_transformer.hpp"
using namespace ngraph;
TEST(exop, create) TEST(exop, create)
{ {
// CpuTransformer transformer; // CpuTransformer transformer;
......
...@@ -20,4 +20,6 @@ ...@@ -20,4 +20,6 @@
#include "names.hpp" #include "names.hpp"
using namespace ngraph;
TEST(names, name) {} TEST(names, name) {}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include "gtest/gtest.h"
#include "transformers/op_graph.hpp"
TEST(op, constant)
{
float expected_value = 42;
op_ptr x = constant(expected_value);
ASSERT_NE(nullptr, x);
EXPECT_EQ(true, x->is_constant());
EXPECT_EQ(false, x->is_input());
EXPECT_EQ(true, x->is_persistent());
EXPECT_EQ(false, x->is_trainable());
EXPECT_EQ(false, x->is_placeholder());
auto ato = std::dynamic_pointer_cast<AssignableTensorOp>(x);
ASSERT_NE(nullptr, ato);
// TODO: fix this
auto ti = ato->m_value;
ASSERT_NE(nullptr, ti);
std::string actual_value = ti->value_string();
std::stringstream ss;
ss << expected_value;
std::string expected_string = ss.str();
EXPECT_STREQ(actual_value.c_str(), expected_string.c_str());
}
// @pytest.fixture()
// def N():
// return ng.make_axis(length=1)
// def test_deriv_missing_connection(N):
// """
// Taking the derivative of an expression with respect to a variable not
// used to compute the expression should raise an exception.
// """
// x = ng.variable([N])
// y = ng.variable([N])
// z = ng.variable([N])
// with pytest.raises(ValueError):
// ng.deriv(x + y, z)
// def test_one():
// # Test that the cacheing on constant one used in DerivOp works.
// op = ng.variable([])
// one_0 = op.one
// one_1 = op.one
// assert one_0 is one_1
// def test_pad_invalid_paddings_length(N):
// """
// pad should raise an exception if the paddings length is not the same as the
// input dimensionality.
// """
// x = ng.variable([N])
// with pytest.raises(ValueError):
// ng.pad(x, [1, 0])
// def test_pad_0(N):
// """
// pad with length 0 should be a nop
// """
// x = ng.variable([N])
// assert ng.pad(x, [0]).axes == x.axes
// def test_pad_mixed():
// """
// mix 0 padding with non-0 padding
// """
// input_axes = ng.make_axes([
// ng.make_axis(1),
// ng.make_axis(1)
// ])
// x = ng.variable(input_axes)
// pad = ng.pad(x, [0, 1])
// assert pad.axes[0] == x.axes[0]
// assert pad.axes[1] != x.axes[1]
// def test_slice_nop():
// """
// slicing an axis shouldn't change the name
// """
// input_axes = ng.make_axes([
// ng.make_axis(1),
// ng.make_axis(1)
// ])
// x = ng.variable(input_axes)
// s = ng.tensor_slice(x, [
// slice(None, None, None),
// slice(None, None, 1),
// ])
// assert s.axes[0] == x.axes[0]
// assert s.axes[1] == x.axes[1]
// def test_tensor_slice():
// """
// slicing a tensor should work like numpy
// """
// input_axes = ng.make_axes([
// ng.make_axis(10),
// ng.make_axis(20),
// ng.make_axis(5)
// ])
// x = ng.placeholder(axes=input_axes)
// assert x[:5].axes.full_lengths == (5, 20, 5)
// assert x[:, 2:7].axes.full_lengths == (10, 5, 5)
// assert x[:5, :, :-1].axes.full_lengths == (5, 20, 4)
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include "transformers/op_graph.hpp" #include "transformers/op_graph.hpp"
using namespace ngraph;
TEST(op_graph, constant) TEST(op_graph, constant)
{ {
float expected_value = 42; float expected_value = 42;
...@@ -62,7 +64,7 @@ Axis N() ...@@ -62,7 +64,7 @@ Axis N()
TEST(op_graph, deriv_missing_connection) TEST(op_graph, deriv_missing_connection)
{ {
// x = ng.variable([N]) // x = ng.variable([N])
auto x = variable({N()}); // auto x = variable({N()});
// y = ng.variable([N]) // y = ng.variable([N])
// z = ng.variable([N]) // z = ng.variable([N])
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "strides.hpp" #include "strides.hpp"
using namespace std; using namespace std;
using namespace ngraph;
TEST(strides, scalar_tree_ctor) TEST(strides, scalar_tree_ctor)
{ {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "util.hpp" #include "util.hpp"
using namespace std; using namespace std;
using namespace ngraph;
TEST(util, split) TEST(util, split)
{ {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "uuid.hpp" #include "uuid.hpp"
using namespace std; using namespace std;
using namespace ngraph;
TEST(uuid, zero) TEST(uuid, zero)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment