Commit a136956b authored by Scott Cyphers's avatar Scott Cyphers

Two shape propagates/checks, bulk of ops.

parent 9d40c6b2
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
{ {
public: public:
Type(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname); Type(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname);
const std::string& c_type_string() const; const std::string& c_type_string() const;
size_t size() const; size_t size() const;
size_t hash() const size_t hash() const
...@@ -37,23 +37,24 @@ namespace ngraph ...@@ -37,23 +37,24 @@ namespace ngraph
std::hash<std::string> h; std::hash<std::string> h;
return h(m_cname); return h(m_cname);
} }
bool operator==(const Type& other) const; bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
private: private:
static std::map<std::string, Type> m_element_list; static std::map<std::string, Type> m_element_list;
size_t m_bitwidth; size_t m_bitwidth;
bool m_is_float; bool m_is_float;
bool m_is_signed; bool m_is_signed;
const std::string m_cname; const std::string m_cname;
}; };
const Type float32_t= Type(32, true, true, "float"); const Type float32_t = Type(32, true, true, "float");
const Type int8_t = Type(8, false, true, "int8_t"); const Type int8_t = Type(8, false, true, "int8_t");
const Type int32_t = Type(32, false, true, "int32_t"); const Type int32_t = Type(32, false, true, "int32_t");
const Type int64_t = Type(64, false, true, "int64_t"); const Type int64_t = Type(64, false, true, "int64_t");
const Type uint8_t = Type(8, false, false, "int8_t"); const Type uint8_t = Type(8, false, false, "int8_t");
const Type uint32_t = Type(32, false, false, "int32_t"); const Type uint32_t = Type(32, false, false, "int32_t");
const Type uint64_t = Type(64, false, false, "int64_t"); const Type uint64_t = Type(64, false, false, "int64_t");
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <stdexcept>
namespace ngraph
{
/// Base error for ngraph runtime errors.
struct ngraph_error : std::runtime_error
{
explicit ngraph_error(const std::string& what_arg)
: std::runtime_error(what_arg)
{
}
explicit ngraph_error(const char* what_arg)
: std::runtime_error(what_arg)
{
}
};
}
...@@ -35,6 +35,8 @@ namespace ngraph ...@@ -35,6 +35,8 @@ namespace ngraph
std::string description() const override { return "Parameter"; } std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
protected: protected:
Function& m_function; Function& m_function;
size_t m_index; size_t m_index;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#pragma once #pragma once
#include "ngraph/element_type.hpp" #include "ngraph/element_type.hpp"
#include "ngraph/except.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op.hpp" #include "ngraph/op.hpp"
......
...@@ -53,7 +53,7 @@ namespace ngraph ...@@ -53,7 +53,7 @@ namespace ngraph
virtual std::string description() const = 0; virtual std::string description() const = 0;
/// Propagate types and check arguments for consistency /// Propagate types and check arguments for consistency
// virtual void propagate_types() = 0; virtual void propagate_types() = 0;
const std::vector<Node::ptr> arguments() const { return m_arguments; } const std::vector<Node::ptr> arguments() const { return m_arguments; }
std::vector<Node::ptr> arguments() { return m_arguments; } std::vector<Node::ptr> arguments() { return m_arguments; }
......
...@@ -21,6 +21,49 @@ ...@@ -21,6 +21,49 @@
namespace ngraph namespace ngraph
{ {
namespace op
{
Node::ptr abs(const Node::ptr& arg);
Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr broadcast(const Node::ptr& tensor,
const Shape& shape,
const std::vector<size_t>& broadcast_axes);
//Node::ptr candidate();
Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr concatenate();
//Node::ptr constant();
//Node::ptr convert();
//Node::ptr convolution();
Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr dot(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr exponential(const Node::ptr& arg0);
Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr get();
Node::ptr greater(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr less(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr log(const Node::ptr& arg0);
//Node::ptr logical();
Node::ptr maximum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr minimum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr multiply(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr negate(const Node::ptr& arg0);
//Node::ptr pad();
Node::ptr power(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr reduce();
Node::ptr remainder(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr reshape(const Node::ptr& arg0, const Shape& shape);
//Node::ptr reverse();
//Node::ptr rng();
//Node::ptr select();
//Node::ptr slice();
Node::ptr subtract(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr transpose();
//Node::ptr tuple();
//Node::ptr while();
}
/** /**
** Every instance of Op corresponds to a unique defined operation. ** Every instance of Op corresponds to a unique defined operation.
**/ **/
...@@ -82,6 +125,9 @@ namespace ngraph ...@@ -82,6 +125,9 @@ namespace ngraph
public: public:
virtual std::string description() const override { return "BuiltinCall"; } virtual std::string description() const override { return "BuiltinCall"; }
// TODO: Implement for each op
virtual void propagate_types() override {}
protected: protected:
BuiltinCall(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& args) BuiltinCall(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& args)
: Call(op, args) : Call(op, args)
...@@ -89,12 +135,29 @@ namespace ngraph ...@@ -89,12 +135,29 @@ namespace ngraph
} }
}; };
namespace op class AbsCall : public BuiltinCall
{ {
std::shared_ptr<Node> broadcast(const Node::ptr& tensor, public:
const Shape& shape, AbsCall(const Node::ptr& arg0)
const std::vector<size_t>& broadcast_axes); : BuiltinCall(s_op, {arg0})
} {
}
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class AddCall : public BuiltinCall
{
public:
AddCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class BroadcastCall : public BuiltinCall class BroadcastCall : public BuiltinCall
{ {
...@@ -111,17 +174,39 @@ namespace ngraph ...@@ -111,17 +174,39 @@ namespace ngraph
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
} }
virtual void propagate_types() override;
protected:
Shape m_shape; Shape m_shape;
std::vector<size_t> m_broadcast_axes; std::vector<size_t> m_broadcast_axes;
static std::shared_ptr<BuiltinOp> s_op;
};
class CeilingCall : public BuiltinCall
{
public:
CeilingCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static std::shared_ptr<BuiltinOp> s_op;
}; };
namespace op class DivideCall : public BuiltinCall
{ {
std::shared_ptr<Node> dot(const Node::ptr& arg0, const Node::ptr& arg1); public:
} DivideCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class DotCall : public BuiltinCall class DotCall : public BuiltinCall
{ {
...@@ -131,7 +216,179 @@ namespace ngraph ...@@ -131,7 +216,179 @@ namespace ngraph
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall(s_op, {arg0, arg1})
{ {
} }
virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class EqualCall : public BuiltinCall
{
public:
EqualCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class ExponentialCall : public BuiltinCall
{
public:
ExponentialCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class FloorCall : public BuiltinCall
{
public:
FloorCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class GreaterCall : public BuiltinCall
{
public:
GreaterCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class LessCall : public BuiltinCall
{
public:
LessCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class LogCall : public BuiltinCall
{
public:
LogCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class MaximumCall : public BuiltinCall
{
public:
MaximumCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class MinimumCall : public BuiltinCall
{
public:
MinimumCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class MultiplyCall : public BuiltinCall
{
public:
MultiplyCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class NegateCall : public BuiltinCall
{
public:
NegateCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class PowerCall : public BuiltinCall
{
public:
PowerCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class RemainderCall : public BuiltinCall
{
public:
RemainderCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
class ReshapeCall : public BuiltinCall
{
public:
ReshapeCall(const Node::ptr& arg0, const Shape& shape)
: BuiltinCall(s_op, {arg0})
, m_shape(shape)
{
}
//virtual void propagate_types() override;
protected:
Shape m_shape;
static std::shared_ptr<BuiltinOp> s_op;
};
class SubtractCall : public BuiltinCall
{
public:
SubtractCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
//virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static std::shared_ptr<BuiltinOp> s_op;
}; };
......
...@@ -32,11 +32,19 @@ namespace ngraph ...@@ -32,11 +32,19 @@ namespace ngraph
{ {
} }
Shape(const std::vector<size_t>& sizes)
: m_sizes(sizes)
{
}
/** /**
** Conversion to a vector of sizes. ** Conversion to a vector of sizes.
**/ **/
operator const std::vector<size_t>&() const { return m_sizes; } operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
protected: protected:
std::vector<size_t> m_sizes; std::vector<size_t> m_sizes;
}; };
......
...@@ -22,6 +22,9 @@ ...@@ -22,6 +22,9 @@
namespace ngraph namespace ngraph
{ {
class TensorViewType;
class TupleType;
/** /**
** ValueType is ** ValueType is
** TensorViewType ** TensorViewType
...@@ -34,6 +37,10 @@ namespace ngraph ...@@ -34,6 +37,10 @@ namespace ngraph
** Preferred handle ** Preferred handle
**/ **/
using ptr = std::shared_ptr<ValueType>; using ptr = std::shared_ptr<ValueType>;
virtual ~ValueType() {}
virtual std::shared_ptr<TensorViewType> as_tensor_view_type() { return nullptr; }
virtual std::shared_ptr<TupleType> as_tuple_type() { return nullptr; }
}; };
/** /**
...@@ -57,6 +64,9 @@ namespace ngraph ...@@ -57,6 +64,9 @@ namespace ngraph
{ {
} }
const element::Type& element_type() const { return m_element_type; }
const Shape shape() const { return m_shape; }
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
Shape m_shape; Shape m_shape;
......
...@@ -24,6 +24,14 @@ Parameter::Parameter(Function& function, size_t index) ...@@ -24,6 +24,14 @@ Parameter::Parameter(Function& function, size_t index)
{ {
} }
void Parameter::propagate_types()
{
if (m_type == nullptr)
{
throw ngraph_error{"Unitialized parameter"};
}
}
Function::Function(size_t n_parameters) Function::Function(size_t n_parameters)
: m_parameters(n_parameters) : m_parameters(n_parameters)
, m_name("Function") , m_name("Function")
......
...@@ -12,11 +12,27 @@ ...@@ -12,11 +12,27 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <algorithm>
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
std::shared_ptr<BuiltinOp> AbsCall::s_op = make_shared<BuiltinOp>("abs");
Node::ptr ngraph::op::abs(const Node::ptr& arg)
{
return make_shared<AbsCall>(arg);
}
std::shared_ptr<BuiltinOp> AddCall::s_op = make_shared<BuiltinOp>("add");
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<AddCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadcast"); std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadcast");
/** /**
...@@ -25,17 +41,191 @@ std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadca ...@@ -25,17 +41,191 @@ std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadca
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast. ** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg. ** the remaining axes in shape must be the same as the shape of arg.
**/ **/
shared_ptr<Node> ngraph::op::broadcast(const Node::ptr& tensor, Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape, const Shape& shape,
const vector<size_t>& broadcast_axes) const vector<size_t>& broadcast_axes)
{ {
return make_shared<BroadcastCall>(tensor, shape, broadcast_axes); return make_shared<BroadcastCall>(tensor, shape, broadcast_axes);
} }
void BroadcastCall::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = arg_type->as_tensor_view_type();
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != arg_tensor_view_type->shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
std::shared_ptr<BuiltinOp> CeilingCall::s_op = make_shared<BuiltinOp>("ceiling");
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<CeilingCall>(arg0, arg1);
}
// 'concatenate',
// 'constant',
// 'convert',
// 'convolution',
std::shared_ptr<BuiltinOp> DivideCall::s_op = make_shared<BuiltinOp>("divide");
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DivideCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> DotCall::s_op = make_shared<BuiltinOp>("dot"); std::shared_ptr<BuiltinOp> DotCall::s_op = make_shared<BuiltinOp>("dot");
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
shared_ptr<Node> ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<DotCall>(arg0, arg1); return make_shared<DotCall>(arg0, arg1);
} }
void DotCall::propagate_types()
{
auto arg0_tensor_type = m_arguments.at(0)->type()->as_tensor_view_type();
auto arg1_tensor_type = m_arguments.at(1)->type()->as_tensor_view_type();
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
if (arg0_tensor_type->element_type() != arg1_tensor_type->element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->shape();
vector<size_t> arg1_shape = arg1_tensor_type->shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
if (arg1_shape.size() > 1)
{
arg1_reduction = arg1_shape.size() - 2;
}
else
{
arg1_reduction = arg1_shape.size() - 1;
}
if (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction))
{
throw ngraph_error("Dot reduction axes not compatible");
}
vector<size_t> result_shape;
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
}
std::shared_ptr<BuiltinOp> ExponentialCall::s_op = make_shared<BuiltinOp>("exponential");
Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{
return make_shared<ExponentialCall>(arg0);
}
std::shared_ptr<BuiltinOp> FloorCall::s_op = make_shared<BuiltinOp>("floor");
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<FloorCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> LogCall::s_op = make_shared<BuiltinOp>("log");
Node::ptr ngraph::op::log(const Node::ptr& arg0)
{
return make_shared<LogCall>(arg0);
}
std::shared_ptr<BuiltinOp> MaximumCall::s_op = make_shared<BuiltinOp>("maximum");
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MaximumCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> MinimumCall::s_op = make_shared<BuiltinOp>("minimum");
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MinimumCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> MultiplyCall::s_op = make_shared<BuiltinOp>("multiply");
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MultiplyCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> NegateCall::s_op = make_shared<BuiltinOp>("negate");
Node::ptr ngraph::op::negate(const Node::ptr& arg0)
{
return make_shared<NegateCall>(arg0);
}
// 'pad',
// 'parameter',
std::shared_ptr<BuiltinOp> PowerCall::s_op = make_shared<BuiltinOp>("power");
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<PowerCall>(arg0, arg1);
}
//'reduce',
std::shared_ptr<BuiltinOp> RemainderCall::s_op = make_shared<BuiltinOp>("remainder");
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<RemainderCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> ReshapeCall::s_op = make_shared<BuiltinOp>("reshape");
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
{
return make_shared<ReshapeCall>(arg0, shape);
}
//'reverse',
//'rng',
// 'select',
//'slice',
std::shared_ptr<BuiltinOp> SubtractCall::s_op = make_shared<BuiltinOp>("subtract");
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<SubtractCall>(arg0, arg1);
}
// 'transpose',
//'tuple',
// 'while'
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
TEST(DISABLED_graph, build_simple) TEST(ngraph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto cluster_0 = make_shared<Function>(4); auto cluster_0 = make_shared<Function>(4);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment