Commit f73e0238 authored by Scott Cyphers's avatar Scott Cyphers

Renamings.

parent c92cabf6
...@@ -24,6 +24,7 @@ namespace ngraph ...@@ -24,6 +24,7 @@ namespace ngraph
{ {
namespace op namespace op
{ {
Node::ptr abs(const Node::ptr& arg); Node::ptr abs(const Node::ptr& arg);
Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1);
...@@ -31,7 +32,7 @@ namespace ngraph ...@@ -31,7 +32,7 @@ namespace ngraph
//Node::ptr convolution(); //Node::ptr convolution();
Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr exponential(const Node::ptr& arg0); Node::ptr exp(const Node::ptr& arg0);
Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr get_tuple_element(); //Node::ptr get_tuple_element();
Node::ptr greater(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr greater(const Node::ptr& arg0, const Node::ptr& arg1);
...@@ -43,7 +44,7 @@ namespace ngraph ...@@ -43,7 +44,7 @@ namespace ngraph
Node::ptr maximum(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr maximum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr minimum(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr minimum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr multiply(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr multiply(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr negate(const Node::ptr& arg0); Node::ptr negative(const Node::ptr& arg0);
//Node::ptr pad(); //Node::ptr pad();
Node::ptr power(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr power(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr reduce(); //Node::ptr reduce();
...@@ -166,10 +167,10 @@ namespace ngraph ...@@ -166,10 +167,10 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class ExponentialOp : public BuiltinOp class ExpOp : public BuiltinOp
{ {
public: public:
ExponentialOp(const Node::ptr& arg0) ExpOp(const Node::ptr& arg0)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
...@@ -262,15 +263,15 @@ namespace ngraph ...@@ -262,15 +263,15 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class NegateOp : public BuiltinOp class NegativeOp : public BuiltinOp
{ {
public: public:
NegateOp(const Node::ptr& arg0) NegativeOp(const Node::ptr& arg0)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
virtual std::string op_name() const override { return "negate"; } virtual std::string op_name() const override { return "negative"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
......
...@@ -21,10 +21,10 @@ namespace ngraph ...@@ -21,10 +21,10 @@ namespace ngraph
Node::ptr concatenate(const Nodes& args); Node::ptr concatenate(const Nodes& args);
} }
class ConcatenateOp : public BuiltinOp class ConcatOp : public BuiltinOp
{ {
public: public:
ConcatenateOp(const Nodes& args) ConcatOp(const Nodes& args)
: BuiltinOp(args) : BuiltinOp(args)
{ {
} }
......
...@@ -19,12 +19,12 @@ ...@@ -19,12 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
void ConcatenateOp::propagate_types() void ConcatOp::propagate_types()
{ {
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
Node::ptr op::concatenate(const std::vector<Node::ptr>& args) Node::ptr op::concatenate(const std::vector<Node::ptr>& args)
{ {
return make_shared<ConcatenateOp>(args); return make_shared<ConcatOp>(args);
} }
...@@ -42,9 +42,9 @@ Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -42,9 +42,9 @@ Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
return make_shared<DivideOp>(arg0, arg1); return make_shared<DivideOp>(arg0, arg1);
} }
Node::ptr ngraph::op::exponential(const Node::ptr& arg0) Node::ptr ngraph::op::exp(const Node::ptr& arg0)
{ {
return make_shared<ExponentialOp>(arg0); return make_shared<ExpOp>(arg0);
} }
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
...@@ -72,9 +72,9 @@ Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -72,9 +72,9 @@ Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
return make_shared<MultiplyOp>(arg0, arg1); return make_shared<MultiplyOp>(arg0, arg1);
} }
Node::ptr ngraph::op::negate(const Node::ptr& arg0) Node::ptr ngraph::op::negative(const Node::ptr& arg0)
{ {
return make_shared<NegateOp>(arg0); return make_shared<NegativeOp>(arg0);
} }
// 'pad', // 'pad',
......
...@@ -16,18 +16,34 @@ ...@@ -16,18 +16,34 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include <memory>
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
template<typename T, typename ...A>
std::shared_ptr<T> myfun(A&&... args)
{
return std::make_shared<T>(args...);
}
template<>
std::shared_ptr<Parameter> myfun<Parameter> (ngraph::element::Type&& element_type, Shape&& shape)
{
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto arg0 = op::parameter(element::Float::type, {7, 3}); auto arg0 = myfun<Parameter>(element::Float::type, Shape{7, 3});
auto arg1 = op::parameter(element::Float::type, {3}); auto arg1 = op::parameter(element::Float::type, Shape{3});
auto arg2 = op::parameter(element::Float::type, {32, 7}); auto arg2 = op::parameter(element::Float::type, Shape{32, 7});
auto arg3 = op::parameter(element::Float::type, {32, 7}); auto arg3 = op::parameter(element::Float::type, Shape{32, 7});
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0}); auto broadcast_1 = op::broadcast(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto b1 = myfun<BroadcastOp>(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto dot = op::dot(arg2, arg0); auto dot = op::dot(arg2, arg0);
auto d1 = myfun<DotOp>(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2); ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0); ASSERT_EQ(dot->arguments()[1], arg0);
...@@ -96,4 +112,7 @@ TEST(build_graph, literal) ...@@ -96,4 +112,7 @@ TEST(build_graph, literal)
} }
// Check argument inverses // Check argument inverses
TEST(build_graph, arg_inverse) {} TEST(build_graph, arg_inverse)
{
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment