Commit 5f724e48 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #69 from NervanaSystems/cyphers/literal

Cyphers/literal
parents 8827be11 6a11672a
......@@ -36,6 +36,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-old-style-cast")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-conversion")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-conversion")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-potentially-evaluated-expression") # Triggers false alarms on typeid
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-conversion")
......
......@@ -18,10 +18,17 @@ set (SRC
tree.cpp
util.cpp
log.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/convert.cpp
ops/constant.cpp
ops/dot.cpp
ops/function.cpp
ops/op.cpp
ops/parameter.cpp
ops/tuple.cpp
types/element_type.cpp
types/type.cpp
)
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include <set>
// Names for types that aren't worth giving their own classes
namespace ngraph
{
class Node;
class Parameter;
/// Zero or more nodes
using Nodes = std::vector<std::shared_ptr<Node>>;
/// A set of indices, for example, reduction axes
using IndexSet = std::set<size_t>;
/// A list of parameters
using Parameters = std::vector<std::shared_ptr<Parameter>>;
}
......@@ -20,6 +20,7 @@
#include <map>
#include <string>
#include <type_traits>
namespace ngraph
{
......@@ -49,12 +50,34 @@ namespace ngraph
const std::string m_cname;
};
const Type float32_t = Type(32, true, true, "float");
const Type int8_t = Type(8, false, true, "int8_t");
const Type int32_t = Type(32, false, true, "int32_t");
const Type int64_t = Type(64, false, true, "int64_t");
const Type uint8_t = Type(8, false, false, "int8_t");
const Type uint32_t = Type(32, false, false, "int32_t");
const Type uint64_t = Type(64, false, false, "int64_t");
// Literals (and probably other things we don't know about yet) need to have their C++ types
// and element types coordinated. Every element type corresponds to a TraitedType which provides
// access to both the instance and the C++ type used to hold the value during compilation.
template <typename T>
class TraitedType : public Type
{
public:
// This is the C++ type used to hold a value of this element type during compilation
using ctype = T;
// This is a reference to an instance of this element type.
static const TraitedType<T>& type;
TraitedType(const std::string& cname)
: Type(sizeof(T) * 8,
std::is_floating_point<T>::value,
std::is_signed<T>::value,
cname)
{
}
};
// Human-readable names for the element types
using Float = TraitedType<float>;
using Int8 = TraitedType<int8_t>;
using Int32 = TraitedType<int32_t>;
using Int64 = TraitedType<int64_t>;
using UInt8 = TraitedType<uint8_t>;
using UInt32 = TraitedType<uint32_t>;
using UInt64 = TraitedType<uint64_t>;
}
}
......@@ -14,10 +14,10 @@
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/type.hpp"
#include "node.hpp"
#include "op.hpp"
#include "ops/parameter.hpp"
#include "type.hpp"
namespace ngraph
{
......
......@@ -18,11 +18,18 @@
#pragma once
#include "ngraph/element_type.hpp"
#include "ngraph/except.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
#include "common.hpp"
#include "element_type.hpp"
#include "except.hpp"
#include "function.hpp"
#include "node.hpp"
#include "op.hpp"
#include "ops/broadcast.hpp"
#include "ops/concatenate.hpp"
#include "ops/constant.hpp"
#include "ops/convert.hpp"
#include "ops/dot.hpp"
#include "ops/parameter.hpp"
#include "ops/tuple.hpp"
#include "shape.hpp"
#include "type.hpp"
......@@ -20,7 +20,8 @@
#include <iostream>
#include "ngraph/type.hpp"
#include "type.hpp"
#include "common.hpp"
namespace ngraph
{
......@@ -37,7 +38,7 @@ namespace ngraph
using ptr = std::shared_ptr<Node>;
protected:
Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = nullptr)
Node(const Nodes& arguments, ValueType::ptr type = nullptr)
: TypedValueMixin(type)
, m_arguments(arguments)
{
......@@ -57,7 +58,7 @@ namespace ngraph
/// Propagate types and check arguments for consistency
virtual void propagate_types() = 0;
const std::vector<Node::ptr>& arguments() const { return m_arguments; }
const Nodes& arguments() const { return m_arguments; }
const std::multiset<Node*>& users() const { return m_users; }
......@@ -75,7 +76,7 @@ namespace ngraph
}
protected:
std::vector<Node::ptr> m_arguments;
Nodes m_arguments;
std::multiset<Node*> m_users;
std::string m_name;
};
......
......@@ -16,9 +16,9 @@
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/parameter.hpp"
#include "ngraph/type.hpp"
#include "node.hpp"
#include "ops/parameter.hpp"
#include "type.hpp"
namespace ngraph
{
......@@ -26,26 +26,20 @@ namespace ngraph
{
Node::ptr abs(const Node::ptr& arg);
Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr broadcast(const Node::ptr& tensor,
const Shape& shape,
const std::vector<size_t>& broadcast_axes);
//Node::ptr candidate();
Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr concatenate();
//Node::ptr constant();
//Node::ptr convert();
//Node::ptr convolution();
Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr dot(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr exponential(const Node::ptr& arg0);
Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr get();
//Node::ptr get_tuple_element();
Node::ptr greater(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr greater_equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr less(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr less_equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr log(const Node::ptr& arg0);
//Node::ptr logical();
//Node::ptr logical(); and, or, not
Node::ptr maximum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr minimum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr multiply(const Node::ptr& arg0, const Node::ptr& arg1);
......@@ -53,15 +47,16 @@ namespace ngraph
//Node::ptr pad();
Node::ptr power(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr reduce();
// Node::ptr reduce_window();
Node::ptr remainder(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr reshape(const Node::ptr& arg0, const Shape& shape);
//Node::ptr reverse();
//Node::ptr rng();
//Node::ptr select();
//Node::ptr select_scatter();
//Node::ptr slice();
Node::ptr subtract(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr transpose();
//Node::ptr tuple();
//Node::ptr while();
}
......@@ -102,7 +97,7 @@ namespace ngraph
/// Name of the builtin op, for debugging and logging.
virtual std::string op_name() const = 0;
// TODO: Implement for each op
// TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {}
protected:
......@@ -135,30 +130,6 @@ namespace ngraph
//virtual void propagate_types() override;
};
class BroadcastOp : public BuiltinOp
{
public:
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastOp(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
virtual std::string op_name() const override { return "broadcast"; }
virtual void propagate_types() override;
protected:
Shape m_shape;
std::vector<size_t> m_broadcast_axes;
};
class CeilingOp : public BuiltinOp
{
public:
......@@ -183,19 +154,6 @@ namespace ngraph
//virtual void propagate_types() override;
};
class DotOp : public BuiltinOp
{
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "dot"; }
virtual void propagate_types() override;
};
class EqualOp : public BuiltinOp
{
public:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
class BroadcastOp : public BuiltinOp
{
public:
using Axes = std::vector<size_t>;
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastOp(const Node::ptr& arg, const Shape& shape, const Axes& broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
virtual std::string op_name() const override { return "broadcast"; }
virtual void propagate_types() override;
protected:
Shape m_shape;
Axes m_broadcast_axes;
};
namespace op
{
Node::ptr broadcast(const Node::ptr& tensor,
const Shape& shape,
const BroadcastOp::Axes&& broadcast_axes);
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
Node::ptr concatenate(const Nodes& args);
}
class ConcatenateOp : public BuiltinOp
{
public:
ConcatenateOp(const Nodes& args)
: BuiltinOp(args)
{
}
virtual std::string op_name() const override { return "concatenate"; }
virtual void propagate_types() override;
};
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "../element_type.hpp"
namespace ngraph
{
// Defines methods to all constant scalars
class ScalarConstantBaseOp : public Node
{
protected:
ScalarConstantBaseOp(const std::shared_ptr<TensorViewType>& type)
: Node({}, type)
{
}
virtual void propagate_types() override;
};
// Implement a constant scalar for each element type.
// The static make method takes a
template <typename T>
class ScalarConstantOp : public ScalarConstantBaseOp
{
public:
// The ngraph element type
using element_type = T;
// The C++ type that holds the element type
using ctype = typename T::ctype;
ScalarConstantOp(typename T::ctype value)
: ScalarConstantBaseOp(std::make_shared<TensorViewType>(T::type, Shape{}))
, m_value(value)
{
}
virtual std::string description() const override { return "ConstantScalar"; }
typename T::ctype value() const { return m_value; }
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
template <typename U>
static std::shared_ptr<ScalarConstantOp<T>> make(U value)
{
return std::make_shared<ScalarConstantOp<T>>(value);
}
protected:
typename T::ctype m_value;
};
using FloatScalarConstantOp = ScalarConstantOp<element::Float>;
using Int8ScalarConstantOp = ScalarConstantOp<element::Int8>;
using Int32ScalarConstantOp = ScalarConstantOp<element::Int32>;
using Int64ScalarConstantOp = ScalarConstantOp<element::Int64>;
using UInt8ScalarConstantOp = ScalarConstantOp<element::UInt8>;
using UInt32ScalarConstantOp = ScalarConstantOp<element::UInt32>;
using UInt64ScalarConstantOp = ScalarConstantOp<element::UInt64>;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
class ConvertOp : public BuiltinOp
{
public:
ConvertOp(const Node::ptr& arg, const ngraph::element::Type& element_type)
: BuiltinOp({arg})
, m_element_type(element_type)
{
}
virtual std::string op_name() const override { return "convert"; }
virtual void propagate_types() override;
protected:
const ngraph::element::Type& m_element_type;
};
namespace op
{
std::shared_ptr<ngraph::ConvertOp> convert(const Node::ptr& arg, const ngraph::element::Type& element_type);
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
class DotOp : public BuiltinOp
{
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "dot"; }
virtual void propagate_types() override;
};
namespace op
{
Node::ptr dot(const Node::ptr& arg0, const Node::ptr& arg1);
}
}
......@@ -14,8 +14,8 @@
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/type.hpp"
#include "../node.hpp"
#include "../type.hpp"
namespace ngraph
{
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
Node::ptr tuple(const Nodes& args);
}
class TupleOp : public BuiltinOp
{
public:
TupleOp(const Nodes& args)
: BuiltinOp(args)
{
}
virtual std::string op_name() const override { return "tuple"; }
virtual void propagate_types() override;
};
}
......@@ -17,8 +17,8 @@
#include <memory>
#include <vector>
#include "ngraph/element_type.hpp"
#include "ngraph/shape.hpp"
#include "element_type.hpp"
#include "shape.hpp"
namespace ngraph
{
......@@ -39,6 +39,8 @@ namespace ngraph
using ptr = std::shared_ptr<ValueType>;
virtual ~ValueType() {}
virtual bool operator==(const ValueType::ptr& that) const = 0;
bool operator!=(const ValueType::ptr& that) const { return !(*this == that); }
};
/**
......@@ -65,6 +67,8 @@ namespace ngraph
const element::Type& element_type() const { return m_element_type; }
const Shape& shape() const { return m_shape; }
virtual bool operator==(const ValueType::ptr& that) const override;
protected:
const element::Type& m_element_type;
Shape m_shape;
......@@ -97,6 +101,8 @@ namespace ngraph
const std::vector<ValueType::ptr> element_types() const { return m_element_types; }
std::vector<ValueType::ptr> element_types() { return m_element_types; }
virtual bool operator==(const ValueType::ptr& that) const override;
protected:
std::vector<ValueType::ptr> m_element_types;
};
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const BroadcastOp::Axes&& broadcast_axes)
{
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != arg_tensor_view_type->shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
void ConcatenateOp::propagate_types()
{
throw ngraph_error("NIY");
}
Node::ptr op::concatenate(const std::vector<Node::ptr>& args)
{
return make_shared<ConcatenateOp>(args);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace ngraph;
void ScalarConstantBaseOp::propagate_types() {}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
void ConvertOp::propagate_types()
{
throw ngraph_error("NIY");
}
shared_ptr<ConvertOp> op::convert(const Node::ptr& arg, const element::Type& element_type)
{
return make_shared<ConvertOp>(arg, element_type);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DotOp>(arg0, arg1);
}
void DotOp::propagate_types()
{
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
if (arg0_tensor_type->element_type() != arg1_tensor_type->element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->shape();
vector<size_t> arg1_shape = arg1_tensor_type->shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
if (arg1_shape.size() > 1)
{
arg1_reduction = arg1_shape.size() - 2;
}
else
{
arg1_reduction = arg1_shape.size() - 1;
}
if (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction))
{
throw ngraph_error("Dot reduction axes not compatible");
}
vector<size_t> result_shape;
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
}
......@@ -29,52 +29,11 @@ Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
return make_shared<AddOp>(arg0, arg1);
}
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const vector<size_t>& broadcast_axes)
{
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != arg_tensor_view_type->shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<CeilingOp>(arg0, arg1);
}
// 'concatenate',
// 'constant',
// 'convert',
// 'convolution',
......@@ -83,50 +42,6 @@ Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
return make_shared<DivideOp>(arg0, arg1);
}
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DotOp>(arg0, arg1);
}
void DotOp::propagate_types()
{
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
if (arg0_tensor_type->element_type() != arg1_tensor_type->element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->shape();
vector<size_t> arg1_shape = arg1_tensor_type->shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
if (arg1_shape.size() > 1)
{
arg1_reduction = arg1_shape.size() - 2;
}
else
{
arg1_reduction = arg1_shape.size() - 1;
}
if (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction))
{
throw ngraph_error("Dot reduction axes not compatible");
}
vector<size_t> result_shape;
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
}
Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{
return make_shared<ExponentialOp>(arg0);
......@@ -163,7 +78,6 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
}
// 'pad',
// 'parameter',
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{
......@@ -193,5 +107,4 @@ Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
}
// 'transpose',
//'tuple',
// 'while'
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
void TupleOp::propagate_types()
{
throw ngraph_error("NIY");
}
Node::ptr op::tuple(const std::vector<Node::ptr>& args)
{
return make_shared<TupleOp>(args);
}
......@@ -17,6 +17,8 @@
#include "ngraph/element_type.hpp"
using namespace ngraph;
std::map<std::string, ngraph::element::Type> ngraph::element::Type::m_element_list;
ngraph::element::Type::Type(size_t bitwidth,
......@@ -46,3 +48,28 @@ size_t ngraph::element::Type::size() const
{
return std::ceil((float)m_bitwidth / 8.0);
}
namespace
{
const element::Float s_float32_t = element::Float{"float"};
const element::Int8 s_int8_t = element::Int8{"int8_t"};
const element::Int32 s_int32_t = element::Int32{"int32_t"};
const element::Int64 s_int64_t = element::Int64{"int64_t"};
const element::UInt8 s_uint8_t = element::UInt8{"uint8_t"};
const element::UInt32 s_uint32_t = element::UInt32{"uint32_t"};
const element::UInt64 s_uint64_t = element::UInt64{"uint64_t"};
}
template <>
const element::TraitedType<float>& element::TraitedType<float>::type = s_float32_t;
template <>
const element::TraitedType<int8_t>& element::TraitedType<int8_t>::type = s_int8_t;
template <>
const element::TraitedType<int32_t>& element::TraitedType<int32_t>::type = s_int32_t;
template <>
const element::TraitedType<int64_t>& element::TraitedType<int64_t>::type = s_int64_t;
template <>
const element::TraitedType<uint8_t>& element::TraitedType<uint8_t>::type = s_uint8_t;
template <>
const element::TraitedType<uint32_t>& element::TraitedType<uint32_t>::type = s_uint32_t;
template <>
const element::TraitedType<uint64_t>& element::TraitedType<uint64_t>::type = s_uint64_t;
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <ngraph/ngraph.hpp>
using namespace std;
using namespace ngraph;
bool TensorViewType::operator==(const ValueType::ptr& that) const
{
auto that_tvt = dynamic_pointer_cast<TensorViewType>(that);
if (nullptr == that_tvt)
{
return false;
}
if (that_tvt->element_type() != m_element_type)
{
return false;
}
if (that_tvt->shape() != m_shape)
{
return false;
}
return true;
}
bool TupleType::operator==(const ValueType::ptr& that) const
{
auto that_tvt = dynamic_pointer_cast<TupleType>(that);
if (nullptr == that_tvt)
{
return false;
}
return that_tvt->element_types() == element_types();
}
......@@ -22,10 +22,10 @@ using namespace ngraph;
TEST(build_graph, build_simple)
{
// Function with 4 parameters
auto arg0 = op::parameter(element::float32_t, {7, 3});
auto arg1 = op::parameter(element::float32_t, {3});
auto arg2 = op::parameter(element::float32_t, {32, 7});
auto arg3 = op::parameter(element::float32_t, {32, 7});
auto arg0 = op::parameter(element::Float::type, {7, 3});
auto arg1 = op::parameter(element::Float::type, {3});
auto arg2 = op::parameter(element::Float::type, {32, 7});
auto arg3 = op::parameter(element::Float::type, {32, 7});
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2);
......@@ -40,7 +40,7 @@ TEST(build_graph, build_simple)
TEST(build_graph, as_type)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::float32_t, Shape{2, 3, 5});
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::Float::type, Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
......@@ -57,14 +57,14 @@ TEST(build_graph, as_type)
// Check node comparisons
TEST(build_graph, node_comparison)
{
auto arg0 = op::parameter(element::float32_t, {32, 3});
auto arg1 = op::parameter(element::float32_t, {3});
auto arg2 = op::parameter(element::float32_t, {32});
auto arg0 = op::parameter(element::Float::type, {32, 3});
auto arg1 = op::parameter(element::Float::type, {3});
auto arg2 = op::parameter(element::Float::type, {32});
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
auto parg = op::parameter(element::float32_t, {});
auto parg = op::parameter(element::Float::type, {});
auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented.
......@@ -72,5 +72,28 @@ TEST(build_graph, node_comparison)
ASSERT_FALSE(pattern_dot->is_same_op_type(add));
}
TEST(build_graph, literal)
{
// float scalar from a float
auto float0 = FloatScalarConstantOp::make(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::type, Shape{});
ASSERT_EQ(float0->value(), 3.0);
ASSERT_EQ(*float0->type(), float_scalar_type);
auto d = op::dot(float0, float0);
ASSERT_EQ(d->arguments().at(0), float0);
ASSERT_EQ(d->arguments().at(1), float0);
// float scalar from an int
auto float1 = FloatScalarConstantOp::make(3);
ASSERT_EQ(float1->value(), 3);
ASSERT_EQ(*float1->type(), float_scalar_type);
auto int32_0 = Int32ScalarConstantOp::make(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::type, Shape{});
ASSERT_EQ(int32_0->value(), 3);
ASSERT_EQ(*int32_0->type(), int32_scalar_type);
ASSERT_NE(*int32_0->type(), float_scalar_type);
}
// Check argument inverses
TEST(build_graph, arg_inverse) {}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment