Commit 807ad951 authored by Scott Cyphers's avatar Scott Cyphers

Add scalar literals.

Element types needed some reorganization to allow literals to know what kind of values to hold.
parent 8827be11
...@@ -19,9 +19,11 @@ set (SRC ...@@ -19,9 +19,11 @@ set (SRC
util.cpp util.cpp
log.cpp log.cpp
ops/function.cpp ops/function.cpp
ops/literal.cpp
ops/op.cpp ops/op.cpp
ops/parameter.cpp ops/parameter.cpp
types/element_type.cpp types/element_type.cpp
types/type.cpp
) )
# NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled # NOTE: We'd prefer to only have the .cpp files *in* the 'transformers' directory be compiled
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <map> #include <map>
#include <string> #include <string>
#include <type_traits>
namespace ngraph namespace ngraph
{ {
...@@ -49,12 +50,34 @@ namespace ngraph ...@@ -49,12 +50,34 @@ namespace ngraph
const std::string m_cname; const std::string m_cname;
}; };
const Type float32_t = Type(32, true, true, "float"); // Literals (and probably other things we don't know about yet) need to have their C++ types
const Type int8_t = Type(8, false, true, "int8_t"); // and element types coordinated. Every element type corresponds to a TraitedType which provides
const Type int32_t = Type(32, false, true, "int32_t"); // access to both the instance and the C++ type used to hold the value during compilation.
const Type int64_t = Type(64, false, true, "int64_t"); template <typename T>
const Type uint8_t = Type(8, false, false, "int8_t"); class TraitedType : public Type
const Type uint32_t = Type(32, false, false, "int32_t"); {
const Type uint64_t = Type(64, false, false, "int64_t"); public:
// This is the C++ type used to hold a value of this element type during compilation
using ctype = T;
// This is a reference to an instance of this element type.
static const TraitedType<T>& type;
TraitedType(const std::string& cname)
: Type(sizeof(T) * 8,
std::is_floating_point<T>::value,
std::is_signed<T>::value,
cname)
{
}
};
// Human-readable names for the element types
using Float = TraitedType<float>;
using Int8 = TraitedType<int8_t>;
using Int32 = TraitedType<int32_t>;
using Int64 = TraitedType<int64_t>;
using UInt8 = TraitedType<uint8_t>;
using UInt32 = TraitedType<uint32_t>;
using UInt64 = TraitedType<uint64_t>;
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/element_type.hpp"
namespace ngraph
{
// Defines methods to all literal scalars
class ScalarLiteralBaseOp : public Node
{
protected:
ScalarLiteralBaseOp(const std::shared_ptr<TensorViewType>& type)
: Node({}, type)
{
}
virtual void propagate_types() override;
};
// Implement a literal scalar for each element type.
// The static make method takes a
template <typename T>
class ScalarLiteralOp : public ScalarLiteralBaseOp
{
public:
// The ngraph element type
using element_type = T;
// The C++ type that holds the element type
using ctype = typename T::ctype;
ScalarLiteralOp(typename T::ctype value)
: ScalarLiteralBaseOp(std::make_shared<TensorViewType>(T::type, ngraph::Shape{}))
, m_value(value)
{
}
virtual std::string description() const override { return "LiteralScalar"; }
typename T::ctype value() const { return m_value; }
// Make a literal from any value that can be converted to the C++ type we use
// to represent the values.
template <typename U>
static std::shared_ptr<ScalarLiteralOp<T>> make(U value)
{
return std::make_shared<ScalarLiteralOp<T>>(
static_cast<ScalarLiteralOp<T>::ctype>(value));
}
protected:
typename T::ctype m_value;
};
using FloatScalarOp = ScalarLiteralOp<element::Float>;
using Int8ScalarOp = ScalarLiteralOp<element::Int8>;
using Int32ScalarOp = ScalarLiteralOp<element::Int32>;
using Int64ScalarOp = ScalarLiteralOp<element::Int64>;
using UInt8ScalarOp = ScalarLiteralOp<element::UInt8>;
using UInt32ScalarOp = ScalarLiteralOp<element::UInt32>;
using UInt64ScalarOp = ScalarLiteralOp<element::UInt64>;
}
\ No newline at end of file
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ngraph/element_type.hpp" #include "ngraph/element_type.hpp"
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/literal.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op.hpp" #include "ngraph/op.hpp"
#include "ngraph/parameter.hpp" #include "ngraph/parameter.hpp"
......
...@@ -39,6 +39,8 @@ namespace ngraph ...@@ -39,6 +39,8 @@ namespace ngraph
using ptr = std::shared_ptr<ValueType>; using ptr = std::shared_ptr<ValueType>;
virtual ~ValueType() {} virtual ~ValueType() {}
virtual bool operator==(const ValueType::ptr& that) const = 0;
bool operator!=(const ValueType::ptr& that) const { return !(*this == that); }
}; };
/** /**
...@@ -65,6 +67,8 @@ namespace ngraph ...@@ -65,6 +67,8 @@ namespace ngraph
const element::Type& element_type() const { return m_element_type; } const element::Type& element_type() const { return m_element_type; }
const Shape& shape() const { return m_shape; } const Shape& shape() const { return m_shape; }
virtual bool operator==(const ValueType::ptr& that) const override;
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
Shape m_shape; Shape m_shape;
...@@ -97,6 +101,8 @@ namespace ngraph ...@@ -97,6 +101,8 @@ namespace ngraph
const std::vector<ValueType::ptr> element_types() const { return m_element_types; } const std::vector<ValueType::ptr> element_types() const { return m_element_types; }
std::vector<ValueType::ptr> element_types() { return m_element_types; } std::vector<ValueType::ptr> element_types() { return m_element_types; }
virtual bool operator==(const ValueType::ptr& that) const override;
protected: protected:
std::vector<ValueType::ptr> m_element_types; std::vector<ValueType::ptr> m_element_types;
}; };
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace ngraph;
void ScalarLiteralBaseOp::propagate_types() {}
\ No newline at end of file
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include "ngraph/element_type.hpp" #include "ngraph/element_type.hpp"
using namespace ngraph;
std::map<std::string, ngraph::element::Type> ngraph::element::Type::m_element_list; std::map<std::string, ngraph::element::Type> ngraph::element::Type::m_element_list;
ngraph::element::Type::Type(size_t bitwidth, ngraph::element::Type::Type(size_t bitwidth,
...@@ -46,3 +48,28 @@ size_t ngraph::element::Type::size() const ...@@ -46,3 +48,28 @@ size_t ngraph::element::Type::size() const
{ {
return std::ceil((float)m_bitwidth / 8.0); return std::ceil((float)m_bitwidth / 8.0);
} }
namespace
{
const element::Float s_float32_t = element::Float{"float"};
const element::Int8 s_int8_t = element::Int8{"int8_t"};
const element::Int32 s_int32_t = element::Int32{"int32_t"};
const element::Int64 s_int64_t = element::Int64{"int64_t"};
const element::UInt8 s_uint8_t = element::UInt8{"uint8_t"};
const element::UInt32 s_uint32_t = element::UInt32{"uint32_t"};
const element::UInt64 s_uint64_t = element::UInt64{"uint64_t"};
}
template <>
const element::TraitedType<float>& element::TraitedType<float>::type = s_float32_t;
template <>
const element::TraitedType<int8_t>& element::TraitedType<int8_t>::type = s_int8_t;
template <>
const element::TraitedType<int32_t>& element::TraitedType<int32_t>::type = s_int32_t;
template <>
const element::TraitedType<int64_t>& element::TraitedType<int64_t>::type = s_int64_t;
template <>
const element::TraitedType<uint8_t>& element::TraitedType<uint8_t>::type = s_uint8_t;
template <>
const element::TraitedType<uint32_t>& element::TraitedType<uint32_t>::type = s_uint32_t;
template <>
const element::TraitedType<uint64_t>& element::TraitedType<uint64_t>::type = s_uint64_t;
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <ngraph/ngraph.hpp>
using namespace std;
using namespace ngraph;
bool TensorViewType::operator==(const ValueType::ptr& that) const
{
auto that_tvt = dynamic_pointer_cast<TensorViewType>(that);
if (nullptr == that_tvt)
{
return false;
}
if (that_tvt->element_type() != m_element_type)
{
return false;
}
if (that_tvt->shape() != m_shape)
{
return false;
}
return true;
}
bool TupleType::operator==(const ValueType::ptr& that) const
{
auto that_tvt = dynamic_pointer_cast<TupleType>(that);
if (nullptr == that_tvt)
{
return false;
}
return that_tvt->element_types() == element_types();
}
...@@ -22,10 +22,10 @@ using namespace ngraph; ...@@ -22,10 +22,10 @@ using namespace ngraph;
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto arg0 = op::parameter(element::float32_t, {7, 3}); auto arg0 = op::parameter(element::Float::type, {7, 3});
auto arg1 = op::parameter(element::float32_t, {3}); auto arg1 = op::parameter(element::Float::type, {3});
auto arg2 = op::parameter(element::float32_t, {32, 7}); auto arg2 = op::parameter(element::Float::type, {32, 7});
auto arg3 = op::parameter(element::float32_t, {32, 7}); auto arg3 = op::parameter(element::Float::type, {32, 7});
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0}); auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto dot = op::dot(arg2, arg0); auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2); ASSERT_EQ(dot->arguments()[0], arg2);
...@@ -40,7 +40,7 @@ TEST(build_graph, build_simple) ...@@ -40,7 +40,7 @@ TEST(build_graph, build_simple)
TEST(build_graph, as_type) TEST(build_graph, as_type)
{ {
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple. // Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::float32_t, Shape{2, 3, 5}); ValueType::ptr tv_vt = make_shared<TensorViewType>(element::Float::type, Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt); auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv); ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt); auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
...@@ -57,14 +57,14 @@ TEST(build_graph, as_type) ...@@ -57,14 +57,14 @@ TEST(build_graph, as_type)
// Check node comparisons // Check node comparisons
TEST(build_graph, node_comparison) TEST(build_graph, node_comparison)
{ {
auto arg0 = op::parameter(element::float32_t, {32, 3}); auto arg0 = op::parameter(element::Float::type, {32, 3});
auto arg1 = op::parameter(element::float32_t, {3}); auto arg1 = op::parameter(element::Float::type, {3});
auto arg2 = op::parameter(element::float32_t, {32}); auto arg2 = op::parameter(element::Float::type, {32});
auto dot = op::dot(arg0, arg1); auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2); auto add = op::add(dot, arg2);
auto parg = op::parameter(element::float32_t, {}); auto parg = op::parameter(element::Float::type, {});
auto pattern_dot = op::dot(parg, parg); auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot)); ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented. // TODO This passes because typeid is not behaving as documented.
...@@ -72,5 +72,28 @@ TEST(build_graph, node_comparison) ...@@ -72,5 +72,28 @@ TEST(build_graph, node_comparison)
ASSERT_FALSE(pattern_dot->is_same_op_type(add)); ASSERT_FALSE(pattern_dot->is_same_op_type(add));
} }
TEST(build_graph, literal)
{
// float scalar from a float
auto float0 = FloatScalarOp::make(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::type, Shape{});
ASSERT_EQ(float0->value(), 3.0);
ASSERT_EQ(*float0->type(), float_scalar_type);
auto d = op::dot(float0, float0);
ASSERT_EQ(d->arguments().at(0), float0);
ASSERT_EQ(d->arguments().at(1), float0);
// float scalar from an int
auto float1 = FloatScalarOp::make(3);
ASSERT_EQ(float1->value(), 3);
ASSERT_EQ(*float1->type(), float_scalar_type);
auto int32_0 = Int32ScalarOp::make(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::type, Shape{});
ASSERT_EQ(int32_0->value(), 3);
ASSERT_EQ(*int32_0->type(), int32_scalar_type);
ASSERT_NE(*int32_0->type(), float_scalar_type);
}
// Check argument inverses // Check argument inverses
TEST(build_graph, arg_inverse) {} TEST(build_graph, arg_inverse) {}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment