Commit 4aa19e31 authored by Scott Cyphers's avatar Scott Cyphers

Add tuple, explicit constants, move some ops to files.

parent b4104ebe
...@@ -36,6 +36,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-old-style-cast") ...@@ -36,6 +36,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-old-style-cast")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-conversion") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-conversion")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-conversion") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-conversion")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-potentially-evaluated-expression") # Triggers false alarms on typeid
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-conversion") # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-conversion")
......
...@@ -18,11 +18,14 @@ set (SRC ...@@ -18,11 +18,14 @@ set (SRC
tree.cpp tree.cpp
util.cpp util.cpp
log.cpp log.cpp
ops/broadcast.cpp
ops/concatenate.cpp ops/concatenate.cpp
ops/constant.cpp
ops/dot.cpp
ops/function.cpp ops/function.cpp
ops/literal.cpp
ops/op.cpp ops/op.cpp
ops/parameter.cpp ops/parameter.cpp
ops/tuple.cpp
types/element_type.cpp types/element_type.cpp
types/type.cpp types/type.cpp
) )
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op.hpp" #include "ngraph/op.hpp"
#include "ngraph/parameter.hpp" #include "ngraph/ops/parameter.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
namespace ngraph namespace ngraph
......
...@@ -21,10 +21,13 @@ ...@@ -21,10 +21,13 @@
#include "ngraph/element_type.hpp" #include "ngraph/element_type.hpp"
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/literal.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op.hpp" #include "ngraph/op.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/concatenate.hpp" #include "ngraph/ops/concatenate.hpp"
#include "ngraph/parameter.hpp" #include "ngraph/ops/constant.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <memory> #include <memory>
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/parameter.hpp" #include "ngraph/ops/parameter.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
namespace ngraph namespace ngraph
...@@ -26,18 +26,11 @@ namespace ngraph ...@@ -26,18 +26,11 @@ namespace ngraph
{ {
Node::ptr abs(const Node::ptr& arg); Node::ptr abs(const Node::ptr& arg);
Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr broadcast(const Node::ptr& tensor,
const Shape& shape,
const std::vector<size_t>& broadcast_axes);
//Node::ptr candidate(); //Node::ptr candidate();
Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr constant();
//Node::ptr convert(); //Node::ptr convert();
//Node::ptr convolution(); //Node::ptr convolution();
Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr dot(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr exponential(const Node::ptr& arg0); Node::ptr exponential(const Node::ptr& arg0);
Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1);
...@@ -61,7 +54,6 @@ namespace ngraph ...@@ -61,7 +54,6 @@ namespace ngraph
//Node::ptr slice(); //Node::ptr slice();
Node::ptr subtract(const Node::ptr& arg0, const Node::ptr& arg1); Node::ptr subtract(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr transpose(); //Node::ptr transpose();
//Node::ptr tuple();
//Node::ptr while(); //Node::ptr while();
} }
...@@ -135,30 +127,6 @@ namespace ngraph ...@@ -135,30 +127,6 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class BroadcastOp : public BuiltinOp
{
public:
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastOp(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
virtual std::string op_name() const override { return "broadcast"; }
virtual void propagate_types() override;
protected:
Shape m_shape;
std::vector<size_t> m_broadcast_axes;
};
class CeilingOp : public BuiltinOp class CeilingOp : public BuiltinOp
{ {
public: public:
...@@ -183,19 +151,6 @@ namespace ngraph ...@@ -183,19 +151,6 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class DotOp : public BuiltinOp
{
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "dot"; }
virtual void propagate_types() override;
};
class EqualOp : public BuiltinOp class EqualOp : public BuiltinOp
{ {
public: public:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
class BroadcastOp : public BuiltinOp
{
public:
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastOp(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
virtual std::string op_name() const override { return "broadcast"; }
virtual void propagate_types() override;
protected:
Shape m_shape;
std::vector<size_t> m_broadcast_axes;
};
namespace op
{
Node::ptr broadcast(const Node::ptr& tensor,
const Shape& shape,
const std::vector<size_t>& broadcast_axes);
}
}
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
namespace ngraph namespace ngraph
{ {
// Defines methods to all literal scalars // Defines methods to all constant scalars
class ScalarLiteralBaseOp : public Node class ScalarConstantBaseOp : public Node
{ {
protected: protected:
ScalarLiteralBaseOp(const std::shared_ptr<TensorViewType>& type) ScalarConstantBaseOp(const std::shared_ptr<TensorViewType>& type)
: Node({}, type) : Node({}, type)
{ {
} }
...@@ -30,10 +30,10 @@ namespace ngraph ...@@ -30,10 +30,10 @@ namespace ngraph
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
// Implement a literal scalar for each element type. // Implement a constant scalar for each element type.
// The static make method takes a // The static make method takes a
template <typename T> template <typename T>
class ScalarLiteralOp : public ScalarLiteralBaseOp class ScalarConstantOp : public ScalarConstantBaseOp
{ {
public: public:
// The ngraph element type // The ngraph element type
...@@ -41,34 +41,34 @@ namespace ngraph ...@@ -41,34 +41,34 @@ namespace ngraph
// The C++ type that holds the element type // The C++ type that holds the element type
using ctype = typename T::ctype; using ctype = typename T::ctype;
ScalarLiteralOp(typename T::ctype value) ScalarConstantOp(typename T::ctype value)
: ScalarLiteralBaseOp(std::make_shared<TensorViewType>(T::type, ngraph::Shape{})) : ScalarConstantBaseOp(std::make_shared<TensorViewType>(T::type, ngraph::Shape{}))
, m_value(value) , m_value(value)
{ {
} }
virtual std::string description() const override { return "LiteralScalar"; } virtual std::string description() const override { return "ConstantScalar"; }
typename T::ctype value() const { return m_value; } typename T::ctype value() const { return m_value; }
// Make a literal from any value that can be converted to the C++ type we use // Make a constant from any value that can be converted to the C++ type we use
// to represent the values. // to represent the values.
template <typename U> template <typename U>
static std::shared_ptr<ScalarLiteralOp<T>> make(U value) static std::shared_ptr<ScalarConstantOp<T>> make(U value)
{ {
return std::make_shared<ScalarLiteralOp<T>>( return std::make_shared<ScalarConstantOp<T>>(
static_cast<ScalarLiteralOp<T>::ctype>(value)); static_cast<ScalarConstantOp<T>::ctype>(value));
} }
protected: protected:
typename T::ctype m_value; typename T::ctype m_value;
}; };
using FloatScalarOp = ScalarLiteralOp<element::Float>; using FloatScalarConstantOp = ScalarConstantOp<element::Float>;
using Int8ScalarOp = ScalarLiteralOp<element::Int8>; using Int8ScalarConstantOp = ScalarConstantOp<element::Int8>;
using Int32ScalarOp = ScalarLiteralOp<element::Int32>; using Int32ScalarConstantOp = ScalarConstantOp<element::Int32>;
using Int64ScalarOp = ScalarLiteralOp<element::Int64>; using Int64ScalarConstantOp = ScalarConstantOp<element::Int64>;
using UInt8ScalarOp = ScalarLiteralOp<element::UInt8>; using UInt8ScalarConstantOp = ScalarConstantOp<element::UInt8>;
using UInt32ScalarOp = ScalarLiteralOp<element::UInt32>; using UInt32ScalarConstantOp = ScalarConstantOp<element::UInt32>;
using UInt64ScalarOp = ScalarLiteralOp<element::UInt64>; using UInt64ScalarConstantOp = ScalarConstantOp<element::UInt64>;
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
class DotOp : public BuiltinOp
{
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "dot"; }
virtual void propagate_types() override;
};
namespace op
{
Node::ptr dot(const Node::ptr& arg0, const Node::ptr& arg1);
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
Node::ptr tuple(const std::vector<Node::ptr>& args);
}
class TupleOp : public BuiltinOp
{
public:
TupleOp(const std::vector<Node::ptr>& args)
: BuiltinOp(args)
{
}
virtual std::string op_name() const override { return "tuple"; }
virtual void propagate_types() override;
};
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const vector<size_t>& broadcast_axes)
{
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != arg_tensor_view_type->shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
...@@ -16,4 +16,4 @@ ...@@ -16,4 +16,4 @@
using namespace ngraph; using namespace ngraph;
void ScalarLiteralBaseOp::propagate_types() {} void ScalarConstantBaseOp::propagate_types() {}
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DotOp>(arg0, arg1);
}
void DotOp::propagate_types()
{
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
if (arg0_tensor_type->element_type() != arg1_tensor_type->element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->shape();
vector<size_t> arg1_shape = arg1_tensor_type->shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
if (arg1_shape.size() > 1)
{
arg1_reduction = arg1_shape.size() - 2;
}
else
{
arg1_reduction = arg1_shape.size() - 1;
}
if (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction))
{
throw ngraph_error("Dot reduction axes not compatible");
}
vector<size_t> result_shape;
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
}
...@@ -29,52 +29,11 @@ Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -29,52 +29,11 @@ Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
return make_shared<AddOp>(arg0, arg1); return make_shared<AddOp>(arg0, arg1);
} }
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const vector<size_t>& broadcast_axes)
{
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != arg_tensor_view_type->shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<CeilingOp>(arg0, arg1); return make_shared<CeilingOp>(arg0, arg1);
} }
// 'concatenate',
// 'constant',
// 'convert', // 'convert',
// 'convolution', // 'convolution',
...@@ -83,50 +42,6 @@ Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -83,50 +42,6 @@ Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
return make_shared<DivideOp>(arg0, arg1); return make_shared<DivideOp>(arg0, arg1);
} }
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DotOp>(arg0, arg1);
}
void DotOp::propagate_types()
{
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
if (arg0_tensor_type->element_type() != arg1_tensor_type->element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->shape();
vector<size_t> arg1_shape = arg1_tensor_type->shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
if (arg1_shape.size() > 1)
{
arg1_reduction = arg1_shape.size() - 2;
}
else
{
arg1_reduction = arg1_shape.size() - 1;
}
if (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction))
{
throw ngraph_error("Dot reduction axes not compatible");
}
vector<size_t> result_shape;
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
}
Node::ptr ngraph::op::exponential(const Node::ptr& arg0) Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{ {
return make_shared<ExponentialOp>(arg0); return make_shared<ExponentialOp>(arg0);
...@@ -163,7 +78,6 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0) ...@@ -163,7 +78,6 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
} }
// 'pad', // 'pad',
// 'parameter',
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
...@@ -193,5 +107,4 @@ Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -193,5 +107,4 @@ Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
} }
// 'transpose', // 'transpose',
//'tuple',
// 'while' // 'while'
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
void TupleOp::propagate_types()
{
throw ngraph_error("NIY");
}
Node::ptr op::tuple(const std::vector<Node::ptr>& args)
{
return make_shared<TupleOp>(args);
}
...@@ -75,7 +75,7 @@ TEST(build_graph, node_comparison) ...@@ -75,7 +75,7 @@ TEST(build_graph, node_comparison)
TEST(build_graph, literal) TEST(build_graph, literal)
{ {
// float scalar from a float // float scalar from a float
auto float0 = FloatScalarOp::make(3.0); auto float0 = FloatScalarConstantOp::make(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::type, Shape{}); auto float_scalar_type = make_shared<TensorViewType>(element::Float::type, Shape{});
ASSERT_EQ(float0->value(), 3.0); ASSERT_EQ(float0->value(), 3.0);
ASSERT_EQ(*float0->type(), float_scalar_type); ASSERT_EQ(*float0->type(), float_scalar_type);
...@@ -84,11 +84,11 @@ TEST(build_graph, literal) ...@@ -84,11 +84,11 @@ TEST(build_graph, literal)
ASSERT_EQ(d->arguments().at(1), float0); ASSERT_EQ(d->arguments().at(1), float0);
// float scalar from an int // float scalar from an int
auto float1 = FloatScalarOp::make(3); auto float1 = FloatScalarConstantOp::make(3);
ASSERT_EQ(float1->value(), 3); ASSERT_EQ(float1->value(), 3);
ASSERT_EQ(*float1->type(), float_scalar_type); ASSERT_EQ(*float1->type(), float_scalar_type);
auto int32_0 = Int32ScalarOp::make(3.0); auto int32_0 = Int32ScalarConstantOp::make(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::type, Shape{}); auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::type, Shape{});
ASSERT_EQ(int32_0->value(), 3); ASSERT_EQ(int32_0->value(), 3);
ASSERT_EQ(*int32_0->type(), int32_scalar_type); ASSERT_EQ(*int32_0->type(), int32_scalar_type);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment