Commit f73e0238 authored by Scott Cyphers's avatar Scott Cyphers

Renamings.

parent c92cabf6
......@@ -24,6 +24,7 @@ namespace ngraph
{
namespace op
{
Node::ptr abs(const Node::ptr& arg);
Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1);
......@@ -31,7 +32,7 @@ namespace ngraph
//Node::ptr convolution();
Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr exponential(const Node::ptr& arg0);
Node::ptr exp(const Node::ptr& arg0);
Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr get_tuple_element();
Node::ptr greater(const Node::ptr& arg0, const Node::ptr& arg1);
......@@ -43,7 +44,7 @@ namespace ngraph
Node::ptr maximum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr minimum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr multiply(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr negate(const Node::ptr& arg0);
Node::ptr negative(const Node::ptr& arg0);
//Node::ptr pad();
Node::ptr power(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr reduce();
......@@ -166,10 +167,10 @@ namespace ngraph
//virtual void propagate_types() override;
};
class ExponentialOp : public BuiltinOp
class ExpOp : public BuiltinOp
{
public:
ExponentialOp(const Node::ptr& arg0)
ExpOp(const Node::ptr& arg0)
: BuiltinOp({arg0})
{
}
......@@ -262,15 +263,15 @@ namespace ngraph
//virtual void propagate_types() override;
};
class NegateOp : public BuiltinOp
class NegativeOp : public BuiltinOp
{
public:
NegateOp(const Node::ptr& arg0)
NegativeOp(const Node::ptr& arg0)
: BuiltinOp({arg0})
{
}
virtual std::string op_name() const override { return "negate"; }
virtual std::string op_name() const override { return "negative"; }
//virtual void propagate_types() override;
};
......
......@@ -21,10 +21,10 @@ namespace ngraph
Node::ptr concatenate(const Nodes& args);
}
class ConcatenateOp : public BuiltinOp
class ConcatOp : public BuiltinOp
{
public:
ConcatenateOp(const Nodes& args)
ConcatOp(const Nodes& args)
: BuiltinOp(args)
{
}
......
......@@ -19,12 +19,12 @@
using namespace std;
using namespace ngraph;
void ConcatenateOp::propagate_types()
void ConcatOp::propagate_types()
{
throw ngraph_error("NIY");
}
Node::ptr op::concatenate(const std::vector<Node::ptr>& args)
{
return make_shared<ConcatenateOp>(args);
return make_shared<ConcatOp>(args);
}
......@@ -42,9 +42,9 @@ Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
return make_shared<DivideOp>(arg0, arg1);
}
Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
Node::ptr ngraph::op::exp(const Node::ptr& arg0)
{
return make_shared<ExponentialOp>(arg0);
return make_shared<ExpOp>(arg0);
}
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
......@@ -72,9 +72,9 @@ Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
return make_shared<MultiplyOp>(arg0, arg1);
}
Node::ptr ngraph::op::negate(const Node::ptr& arg0)
Node::ptr ngraph::op::negative(const Node::ptr& arg0)
{
return make_shared<NegateOp>(arg0);
return make_shared<NegativeOp>(arg0);
}
// 'pad',
......
......@@ -16,18 +16,34 @@
#include "ngraph/ngraph.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
template<typename T, typename ...A>
std::shared_ptr<T> myfun(A&&... args)
{
return std::make_shared<T>(args...);
}
template<>
std::shared_ptr<Parameter> myfun<Parameter> (ngraph::element::Type&& element_type, Shape&& shape)
{
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
TEST(build_graph, build_simple)
{
// Function with 4 parameters
auto arg0 = op::parameter(element::Float::type, {7, 3});
auto arg1 = op::parameter(element::Float::type, {3});
auto arg2 = op::parameter(element::Float::type, {32, 7});
auto arg3 = op::parameter(element::Float::type, {32, 7});
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto arg0 = myfun<Parameter>(element::Float::type, Shape{7, 3});
auto arg1 = op::parameter(element::Float::type, Shape{3});
auto arg2 = op::parameter(element::Float::type, Shape{32, 7});
auto arg3 = op::parameter(element::Float::type, Shape{32, 7});
auto broadcast_1 = op::broadcast(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto b1 = myfun<BroadcastOp>(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto dot = op::dot(arg2, arg0);
auto d1 = myfun<DotOp>(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0);
......@@ -96,4 +112,7 @@ TEST(build_graph, literal)
}
// Check argument inverses
TEST(build_graph, arg_inverse) {}
TEST(build_graph, arg_inverse)
{
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment