Unverified Commit 1f39edbe authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge branch 'master' into gwenger/deprecate_copy_from

parents d739eea6 ef129a77
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
set (SRC set (SRC
assertion.hpp assertion.hpp
attribute_adapter.cpp
attribute_adapter.hpp
attribute_visitor.hpp
autodiff/adjoints.cpp autodiff/adjoints.cpp
autodiff/adjoints.hpp autodiff/adjoints.hpp
axis_set.cpp axis_set.cpp
...@@ -77,7 +80,10 @@ set (SRC ...@@ -77,7 +80,10 @@ set (SRC
dimension.hpp dimension.hpp
distributed.cpp distributed.cpp
distributed.hpp distributed.hpp
enum_names.hpp
except.hpp except.hpp
factory.cpp
factory.hpp
file_util.cpp file_util.cpp
file_util.hpp file_util.hpp
function.cpp function.cpp
...@@ -351,6 +357,8 @@ set (SRC ...@@ -351,6 +357,8 @@ set (SRC
op/fused/gru_cell.hpp op/fused/gru_cell.hpp
op/fused/layer_norm.cpp op/fused/layer_norm.cpp
op/fused/layer_norm.hpp op/fused/layer_norm.hpp
op/fused/log_softmax.cpp
op/fused/log_softmax.hpp
op/fused/lstm_cell.cpp op/fused/lstm_cell.cpp
op/fused/lstm_cell.hpp op/fused/lstm_cell.hpp
op/fused/lstm_sequence.cpp op/fused/lstm_sequence.cpp
...@@ -555,6 +563,7 @@ set (SRC ...@@ -555,6 +563,7 @@ set (SRC
type/float16.cpp type/float16.cpp
type/float16.hpp type/float16.hpp
type/element_type.cpp type/element_type.cpp
type.hpp
util.cpp util.cpp
util.hpp util.hpp
validation_util.cpp validation_util.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <vector>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp"
using namespace std;
using namespace ngraph;
namespace
{
    // Element-wise conversion between two vector-like containers.
    // A must be constructible from a size and provide forward iterators;
    // element conversions (e.g. uint64_t -> int64_t) happen implicitly.
    //
    // Fixed: the index loop previously compared a signed `int` against the
    // unsigned `b.size()` (signed/unsigned mismatch, overflow for huge
    // inputs); a range-for sidesteps the index entirely.
    template <typename A, typename B>
    A copy_from(B& b)
    {
        A result(b.size());
        auto out = result.begin();
        for (auto& element : b)
        {
            *out++ = element;
        }
        return result;
    }
}
namespace ngraph
{
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<float>::type_info;

// Lazily mirrors the wrapped float into the double-typed buffer used by
// the visitor API; the cached value is reused until invalidated by set().
const double& AttributeAdapter<float>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}

// Stores a new value and invalidates the cached buffer.
void AttributeAdapter<float>::set(const double& value)
{
    m_value = value;
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<double>::type_info;

// Identity-typed adapter; still goes through the buffer for a uniform API.
const double& AttributeAdapter<double>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<double>::set(const double& value)
{
    m_value = value;
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<int8_t>::type_info;

// Widens the wrapped int8_t into the canonical int64_t buffer on demand.
const int64_t& AttributeAdapter<int8_t>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<int8_t>::set(const int64_t& value)
{
    m_value = value;
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<int16_t>::type_info;

// Widens the wrapped int16_t into the canonical int64_t buffer on demand.
const int64_t& AttributeAdapter<int16_t>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<int16_t>::set(const int64_t& value)
{
    m_value = value;
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<int32_t>::type_info;

// Widens the wrapped int32_t into the canonical int64_t buffer on demand.
const int64_t& AttributeAdapter<int32_t>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<int32_t>::set(const int64_t& value)
{
    m_value = value;
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<int64_t>::type_info;

// Identity-typed adapter; still goes through the buffer for a uniform API.
const int64_t& AttributeAdapter<int64_t>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<int64_t>::set(const int64_t& value)
{
    m_value = value;
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<uint8_t>::type_info;

// Widens the wrapped uint8_t into the canonical int64_t buffer on demand.
const int64_t& AttributeAdapter<uint8_t>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<uint8_t>::set(const int64_t& value)
{
    m_value = value;
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<uint16_t>::type_info;

// Widens the wrapped uint16_t into the canonical int64_t buffer on demand.
const int64_t& AttributeAdapter<uint16_t>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<uint16_t>::set(const int64_t& value)
{
    m_value = value;
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<uint32_t>::type_info;

// Widens the wrapped uint32_t into the canonical int64_t buffer on demand.
const int64_t& AttributeAdapter<uint32_t>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<uint32_t>::set(const int64_t& value)
{
    m_value = value;
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<uint64_t>::type_info;

// Converts the wrapped uint64_t to the canonical int64_t buffer on demand.
// NOTE(review): values above INT64_MAX wrap on conversion — presumably
// acceptable for attribute values; confirm.
const int64_t& AttributeAdapter<uint64_t>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = m_value;
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<uint64_t>::set(const int64_t& value)
{
    m_value = value;
    m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<vector<int64_t>>::type_info;

// Identity adapter: the attribute already has the canonical buffer type,
// so no caching is needed.
const vector<int64_t>& AttributeAdapter<vector<int64_t>>::get()
{
    return m_value;
}

void AttributeAdapter<vector<int64_t>>::set(const vector<int64_t>& value)
{
    m_value = value;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<vector<uint64_t>>::type_info;

// Lazily converts the uint64_t vector into the canonical int64_t buffer.
const vector<int64_t>& AttributeAdapter<vector<uint64_t>>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = copy_from<vector<int64_t>>(m_value);
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<vector<uint64_t>>::set(const vector<int64_t>& value)
{
    m_value = copy_from<vector<uint64_t>>(value);
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Shape>::type_info;

// Lazily converts the Shape into the canonical int64_t buffer.
const vector<int64_t>& AttributeAdapter<Shape>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = copy_from<vector<int64_t>>(m_value);
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<Shape>::set(const vector<int64_t>& value)
{
    m_value = copy_from<Shape>(value);
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<Strides>::type_info;

// Lazily converts the Strides into the canonical int64_t buffer.
const vector<int64_t>& AttributeAdapter<Strides>::get()
{
    if (m_buffer_valid)
    {
        return m_buffer;
    }
    m_buffer = copy_from<vector<int64_t>>(m_value);
    m_buffer_valid = true;
    return m_buffer;
}

void AttributeAdapter<Strides>::set(const vector<int64_t>& value)
{
    m_value = copy_from<Strides>(value);
    m_buffer_valid = false;
}
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<AxisSet>::type_info;

// Materializes the axis set as a vector<int64_t> buffer, caching the result.
const vector<int64_t>& AttributeAdapter<AxisSet>::get()
{
    if (!m_buffer_valid)
    {
        // Rebuild from scratch and mark the cache valid. The previous
        // implementation never set m_buffer_valid and never cleared the
        // buffer, so every call to get() appended another copy of the axes.
        m_buffer.clear();
        for (auto elt : m_value)
        {
            m_buffer.push_back(elt);
        }
        m_buffer_valid = true;
    }
    return m_buffer;
}

// Replaces the axis set from a vector of axis indices and invalidates the
// cached buffer.
void AttributeAdapter<AxisSet>::set(const vector<int64_t>& value)
{
    m_value = AxisSet();
    for (auto elt : value)
    {
        m_value.insert(elt);
    }
    m_buffer_valid = false;
}

NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<PartialShape>::type_info;
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
NGRAPH_API constexpr DiscreteTypeInfo AttributeAdapter<op::AutoBroadcastSpec>::type_info;
}
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <string>
#include <utility>
#include "ngraph/partial_shape.hpp"
#include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
    template <typename T>
    class ValueAccessor;

    /// \brief Visits the attributes of a node.
    ///
    /// Attributes are the values set when building a graph which are not
    /// computed as the graph executes. Values computed from the graph topology and attributes
    /// during compilation are not attributes.
    class AttributeVisitor
    {
    public:
        virtual ~AttributeVisitor() {}
        // Must implement these methods
        virtual void on_attribute(const std::string& name, std::string& value) = 0;
        virtual void on_attribute(const std::string& name, bool& value) = 0;
        virtual void on_adapter(const std::string& name, ValueAccessor<void>& adapter) = 0;
        // The remaining adapter methods fall back on the void adapter if not implemented.
        // (A stray semicolon after this overload's body was removed for
        // consistency with the overloads below.)
        virtual void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter)
        {
            on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
        }
        virtual void on_adapter(const std::string& name,
                                ValueAccessor<std::vector<int64_t>>& adapter)
        {
            on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
        }
        virtual void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter)
        {
            on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
        }
        virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter)
        {
            on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
        }
        /// Visits a non-primitive attribute by wrapping it in an
        /// AttributeAdapter<T> and dispatching to the matching on_adapter
        /// overload.
        template <typename T>
        // typename std::enable_if<std::is_class<T>::value, void>::type
        void on_attribute(const std::string& name, T& value)
        {
            AttributeAdapter<T> adapter(value);
            on_adapter(name, adapter);
        }
    };
}
...@@ -19,28 +19,30 @@ ...@@ -19,28 +19,30 @@
#include "ngraph/distributed/null.hpp" #include "ngraph/distributed/null.hpp"
#include "ngraph/distributed/open_mpi.hpp" #include "ngraph/distributed/open_mpi.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/type.hpp"
using namespace ngraph; using namespace ngraph;
std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj) namespace ngraph
{ {
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8) template <>
#pragma GCC diagnostic push EnumNames<reduction::Type>& EnumNames<reduction::Type>::get()
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (obj)
{ {
case reduction::Type::SUM: out << "SUM"; break; static auto enum_names = EnumNames<reduction::Type>("reduction::Type",
case reduction::Type::PROD: out << "PROD"; break; {{"SUM", reduction::Type::SUM},
case reduction::Type::MIN: out << "MIN"; break; {"PROD", reduction::Type::PROD},
case reduction::Type::MAX: out << "MAX"; break; {"MIN", reduction::Type::MIN},
{"MAX", reduction::Type::MAX}});
return enum_names;
} }
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop constexpr DiscreteTypeInfo AttributeAdapter<reduction::Type>::type_info;
#endif }
return out;
}; std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj)
{
return out << as_string(obj);
}
static std::unique_ptr<DistributedInterface> s_distributed_interface; static std::unique_ptr<DistributedInterface> s_distributed_interface;
......
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
namespace ngraph namespace ngraph
...@@ -37,6 +39,20 @@ namespace ngraph ...@@ -37,6 +39,20 @@ namespace ngraph
std::ostream& operator<<(std::ostream& out, const Type& obj); std::ostream& operator<<(std::ostream& out, const Type& obj);
} }
template <>
class AttributeAdapter<reduction::Type> : public EnumAttributeAdapterBase<reduction::Type>
{
public:
AttributeAdapter(reduction::Type& value)
: EnumAttributeAdapterBase<reduction::Type>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<reduction::Type>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
class DistributedInterface class DistributedInterface
{ {
public: public:
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "ngraph/check.hpp"
namespace ngraph
{
    /// Uses the pairings defined by EnumNames<EnumType>::get() to convert
    /// between strings and enum values.
    template <typename EnumType>
    class EnumNames
    {
    public:
        /// Converts a string to the corresponding enum value.
        /// Fails (via NGRAPH_CHECK) when \p name is not a registered name.
        static EnumType as_enum(const std::string& name)
        {
            // const& avoids copying each (string, enum) pair per iteration.
            for (const auto& p : get().m_string_enums)
            {
                if (p.first == name)
                {
                    return p.second;
                }
            }
            NGRAPH_CHECK(false, "\"", name, "\"", " is not a member of enum ", get().m_enum_name);
        }

        /// Converts an enum value to its registered string.
        /// Fails (via NGRAPH_CHECK) when \p e is not a registered value.
        static const std::string& as_string(EnumType e)
        {
            for (const auto& p : get().m_string_enums)
            {
                if (p.second == e)
                {
                    return p.first;
                }
            }
            NGRAPH_CHECK(false, " invalid member of enum ", get().m_enum_name);
        }

    private:
        /// Creates the mapping. The table is taken by const reference
        /// (previously by const value, which forced an extra copy).
        EnumNames(const std::string& enum_name,
                  const std::vector<std::pair<std::string, EnumType>>& string_enums)
            : m_enum_name(enum_name)
            , m_string_enums(string_enums)
        {
        }

        /// Must be specialized to return a singleton for each supported enum class.
        static EnumNames<EnumType>& get();

        const std::string m_enum_name;
        std::vector<std::pair<std::string, EnumType>> m_string_enums;
    };

    /// Returns the enum value matching the string.
    template <typename Type, typename Value>
    typename std::enable_if<std::is_convertible<Value, std::string>::value, Type>::type
        as_enum(const Value& value)
    {
        return EnumNames<Type>::as_enum(value);
    }

    /// Returns the string matching the enum value.
    template <typename Value>
    const std::string& as_string(Value value)
    {
        return EnumNames<Value>::as_string(value);
    }
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <mutex>
#include "ngraph/factory.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/parameter.hpp"
using namespace std;
namespace ngraph
{
/// Returns the process-wide mutex that serializes factory-registry access.
std::mutex& get_registry_mutex()
{
    // Meyers singleton: initialized on first use, thread-safe since C++11.
    static std::mutex registry_mutex;
    return registry_mutex;
}
template <>
FactoryRegistry<Node>& FactoryRegistry<Node>::get()
{
    // Meyers singleton: construction of the registry object is thread-safe.
    static FactoryRegistry<Node> registry;
    static mutex init_guard;

    // Always take the lock before inspecting the map. The previous
    // double-checked pattern read m_factory_map.size() without holding the
    // lock, which is a data race when two threads call get() concurrently
    // during first initialization.
    lock_guard<mutex> guard(init_guard);
    if (registry.m_factory_map.empty())
    {
        registry.register_factory<op::Abs>();
        registry.register_factory<op::Acos>();
        registry.register_factory<op::v0::Add>();
        registry.register_factory<op::v1::Add>();
        registry.register_factory<op::All>();
        registry.register_factory<op::AllReduce>();
        registry.register_factory<op::And>();
        registry.register_factory<op::Any>();
        registry.register_factory<op::ArgMax>();
        registry.register_factory<op::ArgMin>();
        registry.register_factory<op::v0::AvgPool>();
        registry.register_factory<op::v0::AvgPoolBackprop>();
        registry.register_factory<op::v1::AvgPool>();
        registry.register_factory<op::v1::AvgPoolBackprop>();
        registry.register_factory<op::BatchNormInference>();
        registry.register_factory<op::BatchNormTraining>();
        registry.register_factory<op::BatchNormTrainingBackprop>();
        registry.register_factory<op::BroadcastDistributed>();
        registry.register_factory<op::v0::Broadcast>();
        registry.register_factory<op::v0::BroadcastLike>();
        registry.register_factory<op::v1::Broadcast>();
        registry.register_factory<op::Ceiling>();
        registry.register_factory<op::Concat>();
        registry.register_factory<op::v1::LogicalAnd>();
        registry.register_factory<op::Parameter>();
    }
    return registry;
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <map>
#include <mutex>
namespace ngraph
{
    std::mutex& get_registry_mutex();

    /// Registry of factories keyed by each registered class's static
    /// `type_info`. Every operation serializes on the mutex returned by
    /// get_registry_mutex().
    template <typename T>
    class FactoryRegistry
    {
    public:
        using Factory = std::function<T*()>;
        using FactoryMap = std::map<decltype(T::type_info), Factory>;

        /// Registers a default-constructing factory for U under U::type_info.
        template <typename U>
        void register_factory()
        {
            std::lock_guard<std::mutex> lock(get_registry_mutex());
            m_factory_map[U::type_info] = []() { return new U(); };
        }

        /// True when a factory has been registered for \p info.
        bool has_factory(const decltype(T::type_info) & info)
        {
            std::lock_guard<std::mutex> lock(get_registry_mutex());
            return m_factory_map.count(info) != 0;
        }

        /// Creates a new instance for \p info, or nullptr when no factory is
        /// registered. The caller owns the returned pointer.
        T* create(const decltype(T::type_info) & info)
        {
            std::lock_guard<std::mutex> lock(get_registry_mutex());
            auto entry = m_factory_map.find(info);
            if (entry == m_factory_map.end())
            {
                return nullptr;
            }
            return entry->second();
        }

        /// Singleton accessor; specialized per T elsewhere.
        static FactoryRegistry<T>& get();

    protected:
        // Need a Compare on type_info
        FactoryMap m_factory_map;
    };
}
...@@ -112,6 +112,7 @@ add_library(onnx_import STATIC ...@@ -112,6 +112,7 @@ add_library(onnx_import STATIC
op/leaky_relu.hpp op/leaky_relu.hpp
op/less.hpp op/less.hpp
op/log.hpp op/log.hpp
op/log_softmax.cpp
op/log_softmax.hpp op/log_softmax.hpp
op/lp_norm.cpp op/lp_norm.cpp
op/lp_norm.hpp op/lp_norm.hpp
...@@ -121,7 +122,6 @@ add_library(onnx_import STATIC ...@@ -121,7 +122,6 @@ add_library(onnx_import STATIC
op/lrn.hpp op/lrn.hpp
op/lstm.cpp op/lstm.cpp
op/lstm.hpp op/lstm.hpp
op/matmul.cpp
op/matmul.hpp op/matmul.hpp
op/matmul_integer.cpp op/matmul_integer.cpp
op/matmul_integer.hpp op/matmul_integer.hpp
......
...@@ -14,8 +14,12 @@ ...@@ -14,8 +14,12 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "matmul.hpp" #include <memory>
#include "ngraph/builder/matmul_factory.hpp"
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/fused/log_softmax.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -25,24 +29,16 @@ namespace ngraph ...@@ -25,24 +29,16 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector matmul(const Node& node) NodeVector log_softmax(const Node& node)
{ {
auto ng_inputs = node.get_ng_inputs(); NodeVector inputs{node.get_ng_inputs()};
auto factory = builder::MatmulFactory( auto data = inputs.at(0);
(OutputVector(std::begin(ng_inputs), std::end(ng_inputs)))); auto data_shape = data->get_shape();
std::size_t left_rank{ng_inputs.at(0)->get_shape().size()}; int axis = node.get_attribute_value<int64_t>("axis", 1);
std::size_t right_rank{ng_inputs.at(1)->get_shape().size()};
return {std::make_shared<ngraph::op::LogSoftmax>(data, axis)};
if (left_rank == 0 || right_rank == 0)
{
NGRAPH_WARN
<< (node) << " "
<< "ONNX standard doesn't allow scalar operands, however nGraph "
"accepts them. Consider use of element-wise multiplication instead "
"to conform with ONNX standard.";
}
return factory.make_matmul_op();
} }
} // namespace set_1 } // namespace set_1
} // namespace op } // namespace op
......
...@@ -19,9 +19,7 @@ ...@@ -19,9 +19,7 @@
#include <memory> #include <memory>
#include "core/node.hpp" #include "core/node.hpp"
#include "ngraph/frontend/onnx_import/op/softmax.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/log.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -31,10 +29,7 @@ namespace ngraph ...@@ -31,10 +29,7 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
inline NodeVector log_softmax(const Node& node) NodeVector log_softmax(const Node& node);
{
return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))};
}
} // namespace set_1 } // namespace set_1
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/fused/matmul.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -27,7 +28,11 @@ namespace ngraph ...@@ -27,7 +28,11 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
NodeVector matmul(const Node& node); NodeVector matmul(const Node& node)
{
return {std::make_shared<ngraph::op::MatMul>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))};
}
} // namespace set_1 } // namespace set_1
} // namespace op } // namespace op
......
...@@ -61,6 +61,8 @@ namespace ngraph ...@@ -61,6 +61,8 @@ namespace ngraph
/// \brief Convenience functions that create addional graph nodes to implement commonly-used /// \brief Convenience functions that create addional graph nodes to implement commonly-used
/// recipes, for example auto-broadcast. /// recipes, for example auto-broadcast.
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/dequantize_builder.hpp" #include "ngraph/builder/dequantize_builder.hpp"
#include "ngraph/builder/numpy_transpose.hpp" #include "ngraph/builder/numpy_transpose.hpp"
...@@ -80,6 +82,7 @@ namespace ngraph ...@@ -80,6 +82,7 @@ namespace ngraph
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/dimension.hpp" #include "ngraph/dimension.hpp"
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "ngraph/factory.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/lambda.hpp" #include "ngraph/lambda.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
...@@ -140,6 +143,7 @@ namespace ngraph ...@@ -140,6 +143,7 @@ namespace ngraph
#include "ngraph/op/fused/gru_cell.hpp" #include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp" #include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp" #include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/log_softmax.hpp"
#include "ngraph/op/fused/lstm_cell.hpp" #include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp" #include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp" #include "ngraph/op/fused/matmul.hpp"
...@@ -221,4 +225,5 @@ namespace ngraph ...@@ -221,4 +225,5 @@ namespace ngraph
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/shape_util.hpp" #include "ngraph/shape_util.hpp"
#include "ngraph/specialize_function.hpp" #include "ngraph/specialize_function.hpp"
#include "ngraph/type.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
...@@ -50,6 +50,7 @@ namespace ngraph ...@@ -50,6 +50,7 @@ namespace ngraph
template <typename NodeType> template <typename NodeType>
class Output; class Output;
class AttributeVisitor;
class Variant; class Variant;
class Node; class Node;
using NodeVector = std::vector<std::shared_ptr<Node>>; using NodeVector = std::vector<std::shared_ptr<Node>>;
...@@ -110,13 +111,14 @@ namespace ngraph ...@@ -110,13 +111,14 @@ namespace ngraph
template <typename NodeType> template <typename NodeType>
friend class Output; friend class Output;
protected: public:
/// Throws if the node is invalid. /// Throws if the node is invalid.
virtual void validate_and_infer_types(); virtual void validate_and_infer_types();
// Called in constructors during transition // Called in constructors during transition
void constructor_validate_and_infer_types(); void constructor_validate_and_infer_types();
protected:
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args( std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec()); const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
void validate_and_infer_elementwise_arithmetic( void validate_and_infer_elementwise_arithmetic(
...@@ -157,6 +159,7 @@ namespace ngraph ...@@ -157,6 +159,7 @@ namespace ngraph
virtual ~Node(); virtual ~Node();
virtual bool visit_attributes(AttributeVisitor& visitor) { return false; }
virtual bool is_unary_elementwise_arithmetic() const { return false; } virtual bool is_unary_elementwise_arithmetic() const { return false; }
virtual bool is_binary_elementwise_arithmetic() const { return false; } virtual bool is_binary_elementwise_arithmetic() const { return false; }
virtual bool is_binary_elementwise_comparison() const { return false; } virtual bool is_binary_elementwise_comparison() const { return false; }
......
...@@ -34,7 +34,7 @@ namespace ngraph ...@@ -34,7 +34,7 @@ namespace ngraph
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an absolute value operation. /// \brief Constructs an absolute value operation.
Abs() = default; Abs() = default;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
/// \brief Constructs an absolute value operation. /// \brief Constructs an absolute value operation.
/// ///
/// \param arg Output that produces the input tensor.<br> /// \param arg Output that produces the input tensor.<br>
......
...@@ -42,7 +42,7 @@ namespace ngraph ...@@ -42,7 +42,7 @@ namespace ngraph
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Acos(const Output<Node>& arg); Acos(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected: protected:
......
...@@ -37,6 +37,12 @@ shared_ptr<Node> op::v0::Add::copy_with_new_args(const NodeVector& new_args) con ...@@ -37,6 +37,12 @@ shared_ptr<Node> op::v0::Add::copy_with_new_args(const NodeVector& new_args) con
return make_shared<op::v0::Add>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v0::Add>(new_args.at(0), new_args.at(1), this->get_autob());
} }
// Delegates attribute visitation to the BinaryElementwiseArithmetic base
// class and reports success.
// NOTE(review): the base call's return value is discarded — presumably
// intentional, but confirm it cannot signal failure.
bool op::v0::Add::visit_attributes(AttributeVisitor& visitor)
{
    BinaryElementwiseArithmetic::visit_attributes(visitor);
    return true;
}
void op::v0::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v0::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (get_autob().m_type != op::AutoBroadcastType::NONE) if (get_autob().m_type != op::AutoBroadcastType::NONE)
...@@ -70,6 +76,12 @@ op::v1::Add::Add(const Output<Node>& arg0, ...@@ -70,6 +76,12 @@ op::v1::Add::Add(const Output<Node>& arg0,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
// Delegates attribute visitation to the BinaryElementwiseArithmetic base
// class and reports success.
// NOTE(review): the base call's return value is discarded — presumably
// intentional, but confirm it cannot signal failure.
bool op::v1::Add::visit_attributes(AttributeVisitor& visitor)
{
    BinaryElementwiseArithmetic::visit_attributes(visitor);
    return true;
}
shared_ptr<Node> op::v1::Add::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Add::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -52,7 +52,7 @@ namespace ngraph ...@@ -52,7 +52,7 @@ namespace ngraph
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
...@@ -90,7 +90,7 @@ namespace ngraph ...@@ -90,7 +90,7 @@ namespace ngraph
AutoBroadcastSpec(AutoBroadcastType::NUMPY)); AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
protected: protected:
......
...@@ -41,7 +41,7 @@ namespace ngraph ...@@ -41,7 +41,7 @@ namespace ngraph
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const Output<Node>& arg, const Output<Node>& reduction_axes); All(const Output<Node>& arg, const Output<Node>& reduction_axes);
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for All. /// \return The default value for All.
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/allreduce.hpp" #include "ngraph/op/allreduce.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/type.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -47,6 +49,12 @@ shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) c ...@@ -47,6 +49,12 @@ shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) c
return make_shared<AllReduce>(new_args.at(0), get_reduce_type()); return make_shared<AllReduce>(new_args.at(0), get_reduce_type());
} }
bool op::AllReduce::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("reduce_type", m_reduce_type);
return true;
}
reduction::Type op::AllReduce::get_reduce_type() const reduction::Type op::AllReduce::get_reduce_type() const
{ {
return m_reduce_type; return m_reduce_type;
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
reduction::Type get_reduce_type() const; reduction::Type get_reduce_type() const;
void set_reduce_type(reduction::Type reduce_type); void set_reduce_type(reduction::Type reduce_type);
bool visit_attributes(AttributeVisitor& visitor) override;
private: private:
reduction::Type m_reduce_type{reduction::Type::SUM}; reduction::Type m_reduce_type{reduction::Type::SUM};
......
...@@ -29,6 +29,12 @@ op::v1::LogicalAnd::LogicalAnd(const Output<Node>& arg0, ...@@ -29,6 +29,12 @@ op::v1::LogicalAnd::LogicalAnd(const Output<Node>& arg0,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v1::LogicalAnd::visit_attributes(AttributeVisitor& visitor)
{
BinaryElementwiseLogical::visit_attributes(visitor);
return true;
}
shared_ptr<Node> op::v1::LogicalAnd::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::LogicalAnd::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
...@@ -45,6 +51,12 @@ op::v0::And::And(const Output<Node>& arg0, ...@@ -45,6 +51,12 @@ op::v0::And::And(const Output<Node>& arg0,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v0::And::visit_attributes(AttributeVisitor& visitor)
{
BinaryElementwiseLogical::visit_attributes(visitor);
return true;
}
shared_ptr<Node> op::v0::And::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::And::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -53,7 +53,7 @@ namespace ngraph ...@@ -53,7 +53,7 @@ namespace ngraph
AutoBroadcastSpec(AutoBroadcastType::NUMPY)); AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
}; };
} // namespace v0 } // namespace v0
...@@ -85,7 +85,7 @@ namespace ngraph ...@@ -85,7 +85,7 @@ namespace ngraph
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
}; };
} }
......
...@@ -44,7 +44,7 @@ namespace ngraph ...@@ -44,7 +44,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
/// \return The default value for Any. /// \return The default value for Any.
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
}; };
......
...@@ -28,6 +28,12 @@ op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& in ...@@ -28,6 +28,12 @@ op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& in
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::ArgMax::visit_attributes(AttributeVisitor& visitor)
{
IndexReduction::visit_attributes(visitor);
return true;
}
shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -41,7 +41,7 @@ namespace ngraph ...@@ -41,7 +41,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
}; };
} }
......
...@@ -28,6 +28,12 @@ op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& in ...@@ -28,6 +28,12 @@ op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& in
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::ArgMin::visit_attributes(AttributeVisitor& visitor)
{
IndexReduction::visit_attributes(visitor);
return true;
}
shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -42,7 +42,7 @@ namespace ngraph ...@@ -42,7 +42,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
}; };
} }
......
...@@ -45,7 +45,7 @@ namespace ngraph ...@@ -45,7 +45,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -46,7 +46,7 @@ namespace ngraph ...@@ -46,7 +46,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -78,6 +78,19 @@ op::v0::AvgPool::AvgPool(const Output<Node>& arg, ...@@ -78,6 +78,19 @@ op::v0::AvgPool::AvgPool(const Output<Node>& arg,
{ {
} }
bool op::v0::AvgPool::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("window_shape", m_window_shape);
visitor.on_attribute("window_movement_strides", m_window_movement_strides);
visitor.on_attribute("padding_below", m_padding_below);
visitor.on_attribute("padding_above", m_padding_above);
visitor.on_attribute("include_padding_in_avg_computation",
m_include_padding_in_avg_computation);
visitor.on_attribute("pad_type", m_pad_type);
visitor.on_attribute("ceil_mode", m_ceil_mode);
return true;
}
void op::v0::AvgPool::validate_and_infer_types() void op::v0::AvgPool::validate_and_infer_types()
{ {
if (0 == m_window_movement_strides.size()) if (0 == m_window_movement_strides.size())
...@@ -251,6 +264,18 @@ op::v0::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, ...@@ -251,6 +264,18 @@ op::v0::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v0::AvgPoolBackprop::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("forward_arg_shape", m_forward_arg_shape);
visitor.on_attribute("window_shape", m_window_shape);
visitor.on_attribute("window_movement_strides", m_window_movement_strides);
visitor.on_attribute("padding_below", m_padding_below);
visitor.on_attribute("padding_above", m_padding_above);
visitor.on_attribute("include_padding_in_avg_computation",
m_include_padding_in_avg_computation);
return true;
}
void op::v0::AvgPoolBackprop::validate_and_infer_types() void op::v0::AvgPoolBackprop::validate_and_infer_types()
{ {
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
...@@ -420,6 +445,18 @@ op::v1::AvgPool::AvgPool(const Output<Node>& arg, ...@@ -420,6 +445,18 @@ op::v1::AvgPool::AvgPool(const Output<Node>& arg,
{ {
} }
bool op::v1::AvgPool::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("kernel", m_kernel);
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("exclude_pad", m_exclude_pad);
visitor.on_attribute("auto_pad", m_auto_pad);
visitor.on_attribute("rounding_type", m_rounding_type);
return true;
}
void op::v1::AvgPool::validate_and_infer_types() void op::v1::AvgPool::validate_and_infer_types()
{ {
if (0 == m_strides.size()) if (0 == m_strides.size())
...@@ -575,6 +612,16 @@ op::v1::AvgPoolBackprop::AvgPoolBackprop(const Output<Node>& delta, ...@@ -575,6 +612,16 @@ op::v1::AvgPoolBackprop::AvgPoolBackprop(const Output<Node>& delta,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v1::AvgPoolBackprop::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("kernel", m_kernel);
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("exclude_pad", m_exclude_pad);
return true;
}
const Shape op::v1::AvgPoolBackprop::get_forward_arg_shape() const const Shape op::v1::AvgPoolBackprop::get_forward_arg_shape() const
{ {
Shape shape; Shape shape;
......
...@@ -130,6 +130,8 @@ namespace ngraph ...@@ -130,6 +130,8 @@ namespace ngraph
/// `[n]` /// `[n]`
AvgPool(const Output<Node>& arg, const Shape& window_shape); AvgPool(const Output<Node>& arg, const Shape& window_shape);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -187,6 +189,7 @@ namespace ngraph ...@@ -187,6 +189,7 @@ namespace ngraph
bool include_padding_in_avg_computation); bool include_padding_in_avg_computation);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -281,6 +284,7 @@ namespace ngraph ...@@ -281,6 +284,7 @@ namespace ngraph
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -337,6 +341,7 @@ namespace ngraph ...@@ -337,6 +341,7 @@ namespace ngraph
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -46,6 +46,12 @@ op::BatchNormTraining::BatchNormTraining(double eps, ...@@ -46,6 +46,12 @@ op::BatchNormTraining::BatchNormTraining(double eps,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::BatchNormTraining::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("epsilon", m_epsilon);
return true;
}
void op::BatchNormTraining::validate_and_infer_types() void op::BatchNormTraining::validate_and_infer_types()
{ {
element::Type result_et; element::Type result_et;
...@@ -129,6 +135,12 @@ op::BatchNormInference::BatchNormInference(double eps, ...@@ -129,6 +135,12 @@ op::BatchNormInference::BatchNormInference(double eps,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::BatchNormInference::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("epsilon", m_epsilon);
return true;
}
void op::BatchNormInference::validate_and_infer_types() void op::BatchNormInference::validate_and_infer_types()
{ {
element::Type result_et; element::Type result_et;
...@@ -191,6 +203,12 @@ op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon, ...@@ -191,6 +203,12 @@ op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::BatchNormTrainingBackprop::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("epsilon", m_epsilon);
return true;
}
void op::BatchNormTrainingBackprop::validate_and_infer_types() void op::BatchNormTrainingBackprop::validate_and_infer_types()
{ {
PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)}; PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};
......
...@@ -43,6 +43,8 @@ namespace ngraph ...@@ -43,6 +43,8 @@ namespace ngraph
const Output<Node>& beta, const Output<Node>& beta,
double epsilon); double epsilon);
bool visit_attributes(AttributeVisitor& visitor) override;
NGRAPH_DEPRECATED_DOC NGRAPH_DEPRECATED_DOC
/// In this version of BatchNorm: /// In this version of BatchNorm:
/// ///
...@@ -108,6 +110,8 @@ namespace ngraph ...@@ -108,6 +110,8 @@ namespace ngraph
const Output<Node>& variance, const Output<Node>& variance,
double epsilon); double epsilon);
bool visit_attributes(AttributeVisitor& visitor) override;
NGRAPH_DEPRECATED_DOC NGRAPH_DEPRECATED_DOC
/// In this version of BatchNorm: /// In this version of BatchNorm:
/// ///
...@@ -184,6 +188,7 @@ namespace ngraph ...@@ -184,6 +188,7 @@ namespace ngraph
const Output<Node>& delta); const Output<Node>& delta);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
double get_eps_value() const { return m_epsilon; } double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; } void set_eps_value(double epsilon) { m_epsilon = epsilon; }
......
...@@ -46,6 +46,12 @@ op::v1::Broadcast::Broadcast(const Output<Node>& arg, ...@@ -46,6 +46,12 @@ op::v1::Broadcast::Broadcast(const Output<Node>& arg,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v1::Broadcast::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("broadcast_spec", m_broadcast_spec);
return true;
}
std::pair<bool, AxisSet> op::v1::Broadcast::get_broadcast_axes() const std::pair<bool, AxisSet> op::v1::Broadcast::get_broadcast_axes() const
{ {
AxisSet broadcast_axes; AxisSet broadcast_axes;
...@@ -286,6 +292,13 @@ op::v0::Broadcast::Broadcast(const Output<Node>& arg, ...@@ -286,6 +292,13 @@ op::v0::Broadcast::Broadcast(const Output<Node>& arg,
{ {
} }
bool op::v0::Broadcast::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("shape", m_shape);
visitor.on_attribute("broadcast_axes", m_broadcast_axes);
return true;
}
void op::v0::Broadcast::validate_and_infer_types() void op::v0::Broadcast::validate_and_infer_types()
{ {
infer_shape(); infer_shape();
...@@ -355,6 +368,14 @@ op::v0::BroadcastLike::BroadcastLike(const Output<Node>& arg, ...@@ -355,6 +368,14 @@ op::v0::BroadcastLike::BroadcastLike(const Output<Node>& arg,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v0::BroadcastLike::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("shape", m_shape);
visitor.on_attribute("broadcast_axes", m_broadcast_axes);
visitor.on_attribute("initial_broadcast_axes", m_initial_broadcast_axes);
return true;
}
shared_ptr<Node> op::v0::BroadcastLike::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::BroadcastLike::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
......
...@@ -46,7 +46,7 @@ namespace ngraph ...@@ -46,7 +46,7 @@ namespace ngraph
Broadcast(const Output<Node>& arg, Broadcast(const Output<Node>& arg,
const Shape& shape, const Shape& shape,
const AxisSet& broadcast_axes); const AxisSet& broadcast_axes);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -94,7 +94,7 @@ namespace ngraph ...@@ -94,7 +94,7 @@ namespace ngraph
BroadcastLike(const Output<Node>& arg, BroadcastLike(const Output<Node>& arg,
const Output<Node>& like_arg, const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes); const AxisSet& initial_broadcast_axes);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -154,7 +154,7 @@ namespace ngraph ...@@ -154,7 +154,7 @@ namespace ngraph
const Output<Node>& target_shape, const Output<Node>& target_shape,
const AutoBroadcastSpec& broadcast_spec = const AutoBroadcastSpec& broadcast_spec =
AutoBroadcastSpec(AutoBroadcastType::NUMPY)); AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override;
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -21,13 +21,19 @@ using namespace ngraph; ...@@ -21,13 +21,19 @@ using namespace ngraph;
constexpr NodeTypeInfo op::BroadcastDistributed::type_info; constexpr NodeTypeInfo op::BroadcastDistributed::type_info;
op::BroadcastDistributed::BroadcastDistributed(const Output<Node>& arg, int root_id) op::BroadcastDistributed::BroadcastDistributed(const Output<Node>& arg, int64_t root_id)
: Op({arg}) : Op({arg})
, m_root_id(root_id) , m_root_id(root_id)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::BroadcastDistributed::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("root_id", m_root_id);
return true;
}
void op::BroadcastDistributed::validate_and_infer_types() void op::BroadcastDistributed::validate_and_infer_types()
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
...@@ -47,12 +53,12 @@ shared_ptr<Node> op::BroadcastDistributed::copy_with_new_args(const NodeVector& ...@@ -47,12 +53,12 @@ shared_ptr<Node> op::BroadcastDistributed::copy_with_new_args(const NodeVector&
return make_shared<BroadcastDistributed>(new_args.at(0), m_root_id); return make_shared<BroadcastDistributed>(new_args.at(0), m_root_id);
} }
int op::BroadcastDistributed::get_root_id() const int64_t op::BroadcastDistributed::get_root_id() const
{ {
return m_root_id; return m_root_id;
} }
void op::BroadcastDistributed::set_root_id(int root_id) void op::BroadcastDistributed::set_root_id(int64_t root_id)
{ {
m_root_id = root_id; m_root_id = root_id;
} }
...@@ -31,17 +31,17 @@ namespace ngraph ...@@ -31,17 +31,17 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"BroadcastDistributed", 0}; static constexpr NodeTypeInfo type_info{"BroadcastDistributed", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
BroadcastDistributed() = default; BroadcastDistributed() = default;
BroadcastDistributed(const Output<Node>& arg, int root_id = 0); BroadcastDistributed(const Output<Node>& arg, int64_t root_id = 0);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
int get_root_id() const; int64_t get_root_id() const;
void set_root_id(int root_id); void set_root_id(int64_t root_id);
private: private:
int m_root_id; int64_t m_root_id;
}; };
} }
} }
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Ceiling(const Output<Node>& arg); Ceiling(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override { return true; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
}; };
......
...@@ -36,6 +36,12 @@ op::Concat::Concat(const NodeVector& args, int64_t axis) ...@@ -36,6 +36,12 @@ op::Concat::Concat(const NodeVector& args, int64_t axis)
{ {
} }
bool op::Concat::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("axis", m_axis);
return true;
}
void op::Concat::validate_and_infer_types() void op::Concat::validate_and_infer_types()
{ {
NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required."); NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required.");
......
...@@ -44,7 +44,7 @@ namespace ngraph ...@@ -44,7 +44,7 @@ namespace ngraph
/// \param args The nodes producing the input tensors. /// \param args The nodes producing the input tensors.
/// \param axis The axis along which to concatenate the input tensors. /// \param axis The axis along which to concatenate the input tensors.
Concat(const NodeVector& args, int64_t axis); Concat(const NodeVector& args, int64_t axis);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include "ngraph/op/fused/log_softmax.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
// Out-of-class definition required for the in-class constexpr static member.
constexpr NodeTypeInfo op::LogSoftmax::type_info;

// Constructs a LogSoftmax node.
//
// \param data Node producing the input tensor.
// \param axis Starting axis of the softmax reduction; may be negative —
//             it is normalized against the input rank in decompose_op().
op::LogSoftmax::LogSoftmax(const Output<Node>& data, int64_t axis)
    : FusedOp({data})
    , m_axis(axis)
{
    // Standard ngraph op construction hook: validates inputs and infers
    // the output element type/shape.
    constructor_validate_and_infer_types();
}
// Decomposes LogSoftmax into log(softmax(data)), where the softmax is
// computed over every axis from the (normalized) m_axis through the
// last axis of the input.
NodeVector op::LogSoftmax::decompose_op() const
{
    const auto data = input_value(0);
    const auto rank = data.get_shape().size();

    // m_axis may be negative; normalize it into [0, rank).
    const auto first_axis = ngraph::normalize_axis(this, m_axis, rank);

    // Reduction axes are the contiguous run [first_axis, rank).
    std::vector<size_t> softmax_axes;
    for (size_t axis = first_axis; axis < rank; ++axis)
    {
        softmax_axes.push_back(axis);
    }

    const auto softmax = std::make_shared<ngraph::op::Softmax>(data, softmax_axes);
    return {std::make_shared<ngraph::op::Log>(softmax)};
}
// Clones this op onto a replacement set of inputs, preserving the axis
// attribute; part of the standard ngraph graph-copy machinery.
shared_ptr<Node> op::LogSoftmax::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    const auto& data = new_args.at(0);
    return std::make_shared<LogSoftmax>(data, get_axis());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief LogSoftmax operation
///
/// Fused op computing log(softmax(data)) over all axes from the given
/// axis through the last axis of the input (see decompose_op in the
/// corresponding .cpp for the exact decomposition).
class LogSoftmax : public ngraph::op::util::FusedOp
{
public:
    NGRAPH_API
    static constexpr NodeTypeInfo type_info{"LogSoftmax", 0};
    LogSoftmax() = default;
    const NodeTypeInfo& get_type_info() const override { return type_info; }
    /// \brief Constructs a LogSoftmax node.
    ///
    /// \param data Node that produces the first input tensor
    /// \param axis Describes the axis of the inputs when coerced to 2D
    LogSoftmax(const Output<Node>& data, int64_t axis);

    virtual NodeVector decompose_op() const override;

    virtual std::shared_ptr<Node>
        copy_with_new_args(const NodeVector& new_args) const override;

    /// \return The raw axis attribute (may be negative; it is normalized
    ///         only at decomposition time).
    int64_t get_axis() const { return m_axis; }
    // NOTE(review): unlike the sibling ops updated in this change set,
    // LogSoftmax does not override visit_attributes, so "axis" is never
    // visited — confirm whether serialization needs it.
protected:
    int64_t m_axis;
};
} // namespace op
} // namespace ngraph
...@@ -64,25 +64,10 @@ NodeVector op::MatMul::decompose_op() const ...@@ -64,25 +64,10 @@ NodeVector op::MatMul::decompose_op() const
auto A = input_value(0); auto A = input_value(0);
auto B = input_value(1); auto B = input_value(1);
// Specification is expecting that A & B have at least 2 dimenstions. const auto a_rank = A.get_shape().size();
// Missing dimensions are padded with 1. const auto b_rank = B.get_shape().size();
int a_rank = A.get_shape().size();
if (a_rank < 2)
{
A = a_rank == 0 ? make_shared<op::Reshape>(A, AxisVector{}, Shape{1, 1})
: make_shared<op::Reshape>(A, AxisVector{1}, Shape{1, A.get_shape()[0]});
a_rank = 2;
}
int b_rank = B.get_shape().size();
if (b_rank < 2)
{
B = b_rank == 0 ? make_shared<op::Reshape>(B, AxisVector{}, Shape{1, 1})
: make_shared<op::Reshape>(B, AxisVector{1}, Shape{1, B.get_shape()[0]});
b_rank = 2;
}
if (m_transpose_a) if (m_transpose_a && a_rank >= 2)
{ {
vector<size_t> axes_order(a_rank); vector<size_t> axes_order(a_rank);
// generate default axes_order. // generate default axes_order.
...@@ -92,7 +77,7 @@ NodeVector op::MatMul::decompose_op() const ...@@ -92,7 +77,7 @@ NodeVector op::MatMul::decompose_op() const
A = builder::reorder_axes(A, axes_order); A = builder::reorder_axes(A, axes_order);
} }
if (m_transpose_b) if (m_transpose_b && b_rank >= 2)
{ {
vector<size_t> axes_order(b_rank); vector<size_t> axes_order(b_rank);
iota(axes_order.begin(), axes_order.end(), 0); iota(axes_order.begin(), axes_order.end(), 0);
......
...@@ -26,8 +26,11 @@ using namespace ngraph; ...@@ -26,8 +26,11 @@ using namespace ngraph;
constexpr NodeTypeInfo op::SquaredDifference::type_info; constexpr NodeTypeInfo op::SquaredDifference::type_info;
op::SquaredDifference::SquaredDifference(const Output<Node>& x1, const Output<Node>& x2) op::SquaredDifference::SquaredDifference(const Output<Node>& x1,
const Output<Node>& x2,
const AutoBroadcastSpec& auto_broadcast)
: FusedOp({x1, x2}) : FusedOp({x1, x2})
, m_autobroadcast(auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -37,19 +40,14 @@ NodeVector op::SquaredDifference::decompose_op() const ...@@ -37,19 +40,14 @@ NodeVector op::SquaredDifference::decompose_op() const
const auto x1 = input_value(0); const auto x1 = input_value(0);
const auto x2 = input_value(1); const auto x2 = input_value(1);
const auto broadcasted = numpy_style_broadcast_values({x1, x2}); const auto difference = make_shared<op::Subtract>(x1, x2, m_autobroadcast);
const auto difference = broadcasted.at(0) - broadcasted.at(1);
return {difference * difference}; return {difference * difference};
} }
shared_ptr<Node> op::SquaredDifference::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::SquaredDifference::copy_with_new_args(const NodeVector& new_args) const
{ {
NODE_VALIDATION_CHECK(this, check_new_args_count(this, new_args);
new_args.size() == 2,
"Expected 2 elements in new_args for the SquaredDifference op but got ",
new_args.size());
return make_shared<SquaredDifference>(new_args.at(0), new_args.at(1)); return make_shared<SquaredDifference>(new_args.at(0), new_args.at(1), get_autob());
} }
...@@ -38,12 +38,24 @@ namespace ngraph ...@@ -38,12 +38,24 @@ namespace ngraph
/// ///
/// \param x1 First input tensor /// \param x1 First input tensor
/// \param x2 Second input tensor /// \param x2 Second input tensor
SquaredDifference(const Output<Node>& x1, const Output<Node>& x2); /// \param auto_broadcast Auto broadcast specification
SquaredDifference(const Output<Node>& x1,
const Output<Node>& x2,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
const AutoBroadcastSpec& get_autob() const override { return m_autobroadcast; }
void set_autob(const AutoBroadcastSpec& auto_broadcast)
{
m_autobroadcast = auto_broadcast;
}
private:
AutoBroadcastSpec m_autobroadcast;
}; };
} } // namespace op
} } // namespace ngraph
...@@ -39,6 +39,7 @@ NGRAPH_OP(GRUCell, ngraph::op) ...@@ -39,6 +39,7 @@ NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op) NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LayerNorm, ngraph::op) NGRAPH_OP(LayerNorm, ngraph::op)
NGRAPH_OP(LayerNormBackprop, ngraph::op) NGRAPH_OP(LayerNormBackprop, ngraph::op)
NGRAPH_OP(LogSoftmax, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op) NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op) NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op) NGRAPH_OP(MatMul, ngraph::op)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <sstream> #include <sstream>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/parameter.hpp" #include "ngraph/op/parameter.hpp"
using namespace std; using namespace std;
...@@ -34,6 +35,14 @@ op::Parameter::Parameter(const element::Type& element_type, ...@@ -34,6 +35,14 @@ op::Parameter::Parameter(const element::Type& element_type,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
// Exposes this Parameter's attributes to an AttributeVisitor (used for
// serialization, cloning, etc.).
// NOTE: the visit order (cacheable, shape, element_type) is observable by
// visitors; keep it stable so saved attribute streams round-trip.
bool op::Parameter::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("cacheable", m_cacheable);
    visitor.on_attribute("shape", m_partial_shape);
    visitor.on_attribute("element_type", m_element_type);
    return true;
}
void op::Parameter::validate_and_infer_types() void op::Parameter::validate_and_infer_types()
{ {
Op::validate_and_infer_types(); Op::validate_and_infer_types();
......
...@@ -49,6 +49,8 @@ namespace ngraph ...@@ -49,6 +49,8 @@ namespace ngraph
const PartialShape& pshape, const PartialShape& pshape,
const bool cacheable = false); const bool cacheable = false);
bool visit_attributes(AttributeVisitor& visitor) override;
bool is_parameter() const override { return true; } bool is_parameter() const override { return true; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -15,17 +15,112 @@ ...@@ -15,17 +15,112 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
#include "ngraph/enum_names.hpp"
using namespace ngraph; using namespace ngraph;
std::ostream& op::operator<<(std::ostream& s, const op::AutoBroadcastType& type) namespace ngraph
{ {
switch (type) template <>
EnumNames<op::PadMode>& EnumNames<op::PadMode>::get()
{ {
case op::AutoBroadcastType::NONE: s << "NONE"; break; static auto enum_names = EnumNames<op::PadMode>("op::PadMode",
case op::AutoBroadcastType::NUMPY: s << "NUMPY"; break; {{"CONSTANT", op::PadMode::CONSTANT},
case op::AutoBroadcastType::PDPD: s << "PDPD"; break; {"EDGE", op::PadMode::EDGE},
default: s << "Undefined Type"; {"REFLECT", op::PadMode::REFLECT},
{"SYMMETRIC", op::PadMode::SYMMETRIC}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::PadMode>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::PadMode& type)
{
return s << as_string(type);
}
template <>
EnumNames<op::PadType>& EnumNames<op::PadType>::get()
{
static auto enum_names = EnumNames<op::PadType>("op::PadType",
{{"EXPLICIT", op::PadType::EXPLICIT},
{"SAME_LOWER", op::PadType::SAME_LOWER},
{"SAME_UPPER", op::PadType::SAME_UPPER},
{"VALID", op::PadType::VALID}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::PadType>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::PadType& type)
{
return s << as_string(type);
}
template <>
EnumNames<op::RoundingType>& EnumNames<op::RoundingType>::get()
{
static auto enum_names = EnumNames<op::RoundingType>(
"op::RoundingType",
{{"FLOOR", op::RoundingType::FLOOR}, {"CEIL", op::RoundingType::CEIL}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::RoundingType>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::RoundingType& type)
{
return s << as_string(type);
}
template <>
EnumNames<op::AutoBroadcastType>& EnumNames<op::AutoBroadcastType>::get()
{
static auto enum_names =
EnumNames<op::AutoBroadcastType>("op::AutoBroadcastType",
{{"NONE", op::AutoBroadcastType::NONE},
{"NUMPY", op::AutoBroadcastType::NUMPY},
{"PDPD", op::AutoBroadcastType::PDPD}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::AutoBroadcastType>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::AutoBroadcastType& type)
{
return s << as_string(type);
}
template <>
EnumNames<op::EpsMode>& EnumNames<op::EpsMode>::get()
{
static auto enum_names = EnumNames<op::EpsMode>(
"op::EpsMode", {{"ADD", op::EpsMode::ADD}, {"MAX", op::EpsMode::MAX}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::EpsMode>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::EpsMode& type)
{
return s << as_string(type);
}
template <>
EnumNames<op::TopKSortType>& EnumNames<op::TopKSortType>::get()
{
static auto enum_names =
EnumNames<op::TopKSortType>("op::TopKSortType",
{{"NONE", op::TopKSortType::NONE},
{"SORT_INDICES", op::TopKSortType::SORT_INDICES},
{"SORT_VALUES", op::TopKSortType::SORT_VALUES}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::TopKSortType>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::TopKSortType& type)
{
return s << as_string(type);
} }
return s;
} }
...@@ -19,6 +19,9 @@ ...@@ -19,6 +19,9 @@
#include <cstddef> #include <cstddef>
#include <ostream> #include <ostream>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/type.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
...@@ -32,6 +35,25 @@ namespace ngraph ...@@ -32,6 +35,25 @@ namespace ngraph
SYMMETRIC SYMMETRIC
}; };
std::ostream& operator<<(std::ostream& s, const PadMode& type);
}
template <>
class AttributeAdapter<op::PadMode> : public EnumAttributeAdapterBase<op::PadMode>
{
public:
AttributeAdapter(op::PadMode& value)
: EnumAttributeAdapterBase<op::PadMode>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::PadMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
/// \brief Padding Type used for `Convolution` and `Pooling` /// \brief Padding Type used for `Convolution` and `Pooling`
/// ///
/// Follows ONNX padding type definitions /// Follows ONNX padding type definitions
...@@ -54,6 +76,25 @@ namespace ngraph ...@@ -54,6 +76,25 @@ namespace ngraph
NOTSET = EXPLICIT, NOTSET = EXPLICIT,
}; };
std::ostream& operator<<(std::ostream& s, const PadType& type);
}
template <>
class AttributeAdapter<op::PadType> : public EnumAttributeAdapterBase<op::PadType>
{
public:
AttributeAdapter(op::PadType& value)
: EnumAttributeAdapterBase<op::PadType>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::PadType>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
/// \brief Rounding Type used for `Pooling` operators. /// \brief Rounding Type used for `Pooling` operators.
enum class RoundingType enum class RoundingType
{ {
...@@ -61,6 +102,25 @@ namespace ngraph ...@@ -61,6 +102,25 @@ namespace ngraph
CEIL = 1, CEIL = 1,
}; };
std::ostream& operator<<(std::ostream& s, const RoundingType& type);
}
template <>
class AttributeAdapter<op::RoundingType> : public EnumAttributeAdapterBase<op::RoundingType>
{
public:
AttributeAdapter(op::RoundingType& value)
: EnumAttributeAdapterBase<op::RoundingType>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::RoundingType>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
/// \brief Specifies the algorithm to use for implicit broadcasting of a tensor /// \brief Specifies the algorithm to use for implicit broadcasting of a tensor
/// to align with another tensor /// to align with another tensor
/// ///
...@@ -107,6 +167,26 @@ namespace ngraph ...@@ -107,6 +167,26 @@ namespace ngraph
PDPD PDPD
}; };
std::ostream& operator<<(std::ostream& s, const AutoBroadcastType& type);
}
template <>
class AttributeAdapter<op::AutoBroadcastType>
: public EnumAttributeAdapterBase<op::AutoBroadcastType>
{
public:
AttributeAdapter(op::AutoBroadcastType& value)
: EnumAttributeAdapterBase<op::AutoBroadcastType>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::AutoBroadcastType>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
/// \brief Specifies how eps is combined with L2 value /// \brief Specifies how eps is combined with L2 value
enum class EpsMode enum class EpsMode
{ {
...@@ -116,6 +196,25 @@ namespace ngraph ...@@ -116,6 +196,25 @@ namespace ngraph
MAX MAX
}; };
std::ostream& operator<<(std::ostream& s, const EpsMode& type);
}
template <>
class AttributeAdapter<op::EpsMode> : public EnumAttributeAdapterBase<op::EpsMode>
{
public:
AttributeAdapter(op::EpsMode& value)
: EnumAttributeAdapterBase<op::EpsMode>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::EpsMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
enum class TopKSortType enum class TopKSortType
{ {
// Returned values are not sorted // Returned values are not sorted
...@@ -125,7 +224,25 @@ namespace ngraph ...@@ -125,7 +224,25 @@ namespace ngraph
// Sort result based on element values // Sort result based on element values
SORT_VALUES, SORT_VALUES,
}; };
std::ostream& operator<<(std::ostream& s, const TopKSortType& type);
}
template <>
class AttributeAdapter<op::TopKSortType> : public EnumAttributeAdapterBase<op::TopKSortType>
{
public:
AttributeAdapter(op::TopKSortType& value)
: EnumAttributeAdapterBase<op::TopKSortType>(value)
{
}
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::TopKSortType>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
namespace op
{
/// \brief Implicit broadcast specification /// \brief Implicit broadcast specification
struct AutoBroadcastSpec struct AutoBroadcastSpec
{ {
...@@ -153,7 +270,5 @@ namespace ngraph ...@@ -153,7 +270,5 @@ namespace ngraph
return a.m_type == m_type && a.m_axis == m_axis; return a.m_type == m_type && a.m_axis == m_axis;
} }
}; };
std::ostream& operator<<(std::ostream& s, const AutoBroadcastType& type);
} }
} }
...@@ -54,3 +54,9 @@ void op::util::BinaryElementwiseArithmetic::validate_and_infer_types() ...@@ -54,3 +54,9 @@ void op::util::BinaryElementwiseArithmetic::validate_and_infer_types()
{ {
validate_and_infer_elementwise_arithmetic(m_autob); validate_and_infer_elementwise_arithmetic(m_autob);
} }
// Exposes the auto-broadcast specification to attribute visitors so it
// survives save/restore of the node.
bool op::util::BinaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("autob", m_autob);
    return true;
}
...@@ -86,6 +86,8 @@ namespace ngraph ...@@ -86,6 +86,8 @@ namespace ngraph
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; } void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool is_binary_elementwise_arithmetic() const override { return true; } bool is_binary_elementwise_arithmetic() const override { return true; }
bool supports_auto_broadcast() const override { return true; } bool supports_auto_broadcast() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private: private:
AutoBroadcastSpec m_autob; AutoBroadcastSpec m_autob;
}; };
......
...@@ -55,3 +55,9 @@ void op::util::BinaryElementwiseComparison::validate_and_infer_types() ...@@ -55,3 +55,9 @@ void op::util::BinaryElementwiseComparison::validate_and_infer_types()
set_output_type(0, element::boolean, args_pshape); set_output_type(0, element::boolean, args_pshape);
} }
// Exposes the auto-broadcast specification to attribute visitors so it
// survives save/restore of the node.
bool op::util::BinaryElementwiseComparison::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("autob", m_autob);
    return true;
}
...@@ -92,6 +92,8 @@ namespace ngraph ...@@ -92,6 +92,8 @@ namespace ngraph
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; } void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool supports_auto_broadcast() const override { return true; } bool supports_auto_broadcast() const override { return true; }
bool is_binary_elementwise_comparison() const override { return true; } bool is_binary_elementwise_comparison() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private: private:
AutoBroadcastSpec m_autob; AutoBroadcastSpec m_autob;
}; };
......
...@@ -52,3 +52,9 @@ void op::util::BinaryElementwiseLogical::validate_and_infer_types() ...@@ -52,3 +52,9 @@ void op::util::BinaryElementwiseLogical::validate_and_infer_types()
{ {
validate_and_infer_elementwise_logical(m_autob); validate_and_infer_elementwise_logical(m_autob);
} }
// Exposes the auto-broadcast specification to attribute visitors so it
// survives save/restore of the node.
bool op::util::BinaryElementwiseLogical::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("autob", m_autob);
    return true;
}
...@@ -88,6 +88,8 @@ namespace ngraph ...@@ -88,6 +88,8 @@ namespace ngraph
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; } void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool supports_auto_broadcast() const override { return true; } bool supports_auto_broadcast() const override { return true; }
bool is_binary_elementwise_logical() const override { return true; } bool is_binary_elementwise_logical() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private: private:
AutoBroadcastSpec m_autob; AutoBroadcastSpec m_autob;
}; };
......
...@@ -26,7 +26,7 @@ op::util::IndexReduction::IndexReduction() ...@@ -26,7 +26,7 @@ op::util::IndexReduction::IndexReduction()
} }
op::util::IndexReduction::IndexReduction(const Output<Node>& arg, op::util::IndexReduction::IndexReduction(const Output<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type) const element::Type& index_element_type)
: Op({arg}) : Op({arg})
{ {
...@@ -35,7 +35,7 @@ op::util::IndexReduction::IndexReduction(const Output<Node>& arg, ...@@ -35,7 +35,7 @@ op::util::IndexReduction::IndexReduction(const Output<Node>& arg,
} }
op::util::IndexReduction::IndexReduction(const std::shared_ptr<Node>& arg, op::util::IndexReduction::IndexReduction(const std::shared_ptr<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type) const element::Type& index_element_type)
: Op(check_single_output_args({arg})) : Op(check_single_output_args({arg}))
{ {
...@@ -45,7 +45,7 @@ op::util::IndexReduction::IndexReduction(const std::shared_ptr<Node>& arg, ...@@ -45,7 +45,7 @@ op::util::IndexReduction::IndexReduction(const std::shared_ptr<Node>& arg,
op::util::IndexReduction::IndexReduction(const std::string& node_type, op::util::IndexReduction::IndexReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type) const element::Type& index_element_type)
: Op(node_type, check_single_output_args({arg})) : Op(node_type, check_single_output_args({arg}))
{ {
...@@ -53,11 +53,11 @@ op::util::IndexReduction::IndexReduction(const std::string& node_type, ...@@ -53,11 +53,11 @@ op::util::IndexReduction::IndexReduction(const std::string& node_type,
set_index_element_type(index_element_type); set_index_element_type(index_element_type);
} }
size_t op::util::IndexReduction::get_reduction_axis() const uint64_t op::util::IndexReduction::get_reduction_axis() const
{ {
return m_axis; return m_axis;
} }
void op::util::IndexReduction::set_reduction_axis(size_t value) void op::util::IndexReduction::set_reduction_axis(uint64_t value)
{ {
m_axis = value; m_axis = value;
} }
...@@ -125,3 +125,10 @@ void op::util::IndexReduction::generate_adjoints(autodiff::Adjoints& /* adjoints ...@@ -125,3 +125,10 @@ void op::util::IndexReduction::generate_adjoints(autodiff::Adjoints& /* adjoints
{ {
throw ngraph_error("Forward-propagation-only operation"); throw ngraph_error("Forward-propagation-only operation");
} }
// Visits the reduction axis and the element type produced for the result
// indices; attribute names must stay stable for (de)serialization.
bool op::util::IndexReduction::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("axis", m_axis);
    visitor.on_attribute("index_element_type", m_index_element_type);
    return true;
}
...@@ -35,27 +35,28 @@ namespace ngraph ...@@ -35,27 +35,28 @@ namespace ngraph
IndexReduction(); IndexReduction();
IndexReduction(const Output<Node>& arg, IndexReduction(const Output<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type); const element::Type& index_element_type);
IndexReduction(const std::shared_ptr<Node>& arg, IndexReduction(const std::shared_ptr<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type); const element::Type& index_element_type);
IndexReduction(const std::string& node_type, IndexReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& arg,
size_t axis, uint64_t axis,
const element::Type& index_element_type); const element::Type& index_element_type);
public: public:
size_t get_reduction_axis() const; uint64_t get_reduction_axis() const;
void set_reduction_axis(size_t value); void set_reduction_axis(uint64_t value);
element::Type get_index_element_type() const; element::Type get_index_element_type() const;
void set_index_element_type(const element::Type& index_element_type); void set_index_element_type(const element::Type& index_element_type);
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
protected: protected:
size_t m_axis{0}; uint64_t m_axis{0};
element::Type m_index_element_type; element::Type m_index_element_type;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -1210,7 +1210,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes( ...@@ -1210,7 +1210,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
else if (typeid(ngraph::op::GeluBackpropFactor) == typeid(node)) else if (typeid(ngraph::op::GeluBackpropFactor) == typeid(node))
{ {
#if MKLDNN_VERSION_MAJOR < 1 #if MKLDNN_VERSION_MAJOR < 1
return ((node.input(0).get_element_type() == element::f32) ? true : false); // TODO: (gauri): need to differentiate which implementation : erf vs tanh
return false;
#else #else
// TODO: will be supported in mkldnn v1.1 // TODO: will be supported in mkldnn v1.1
return false; return false;
...@@ -1219,7 +1220,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes( ...@@ -1219,7 +1220,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
else if (typeid(ngraph::op::Gelu) == typeid(node)) else if (typeid(ngraph::op::Gelu) == typeid(node))
{ {
#if MKLDNN_VERSION_MAJOR < 1 #if MKLDNN_VERSION_MAJOR < 1
return ((node.input(0).get_element_type() == element::f32) ? true : false); // TODO: (gauri): need to differentiate which implementation : erf vs tanh
return false;
#else #else
// TODO: will be supported in mkldnn v1.1 // TODO: will be supported in mkldnn v1.1
return false; return false;
......
...@@ -22,9 +22,5 @@ lrn_across_nw ...@@ -22,9 +22,5 @@ lrn_across_nw
lrn_across_empty lrn_across_empty
lrn_6D_across_2_axes lrn_6D_across_2_axes
# Gelu tests not supported in CPU backend, we use mkldnn gelubackprop (and not factor)
gelu_backprop_factor_f32
backwards_gelu_f32
# ONNX TopK with dynamic K # ONNX TopK with dynamic K
top_k_opset_10 top_k_opset_10
...@@ -83,6 +83,7 @@ ...@@ -83,6 +83,7 @@
#include "ngraph/op/fused/gru_cell.hpp" #include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp" #include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp" #include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/log_softmax.hpp"
#include "ngraph/op/fused/lstm_cell.hpp" #include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp" #include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp" #include "ngraph/op/fused/matmul.hpp"
...@@ -1823,6 +1824,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1823,6 +1824,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast")); args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break; break;
} }
case OP_TYPEID::LogSoftmax:
{
auto axis = node_js.at("axis").get<int64_t>();
node = make_shared<op::LogSoftmax>(args[0], axis);
break;
}
case OP_TYPEID::LRN: case OP_TYPEID::LRN:
{ {
auto alpha = node_js.at("alpha").get<double>(); auto alpha = node_js.at("alpha").get<double>();
...@@ -2605,7 +2612,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2605,7 +2612,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
case OP_TYPEID::SquaredDifference: case OP_TYPEID::SquaredDifference:
{ {
node = make_shared<op::SquaredDifference>(args[0], args[1]); node = make_shared<op::SquaredDifference>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break; break;
} }
case OP_TYPEID::Squeeze: case OP_TYPEID::Squeeze:
...@@ -3585,6 +3593,12 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -3585,6 +3593,12 @@ json JSONSerializer::serialize_node(const Node& n)
} }
break; break;
} }
case OP_TYPEID::LogSoftmax:
{
auto tmp = static_cast<const op::LogSoftmax*>(&n);
node["axis"] = tmp->get_axis();
break;
}
case OP_TYPEID::LRN: case OP_TYPEID::LRN:
{ {
auto tmp = static_cast<const op::LRN*>(&n); auto tmp = static_cast<const op::LRN*>(&n);
...@@ -4067,7 +4081,14 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4067,7 +4081,14 @@ json JSONSerializer::serialize_node(const Node& n)
} }
case OP_TYPEID::Sqrt: { break; case OP_TYPEID::Sqrt: { break;
} }
case OP_TYPEID::SquaredDifference: { break; case OP_TYPEID::SquaredDifference:
{
auto tmp = static_cast<const op::SquaredDifference*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
break;
} }
case OP_TYPEID::Squeeze: { break; case OP_TYPEID::Squeeze: { break;
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#pragma once #pragma once
#include <cstdint>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
......
...@@ -71,10 +71,11 @@ float16::float16(float value) ...@@ -71,10 +71,11 @@ float16::float16(float value)
// denorm // denorm
biased_exp = 0; biased_exp = 0;
raw_frac |= hidden_one; raw_frac |= hidden_one;
uint32_t shift = (-15 - exp) + (23 - frac_size) + 1; uint32_t exp_shift = (-15 - exp) + 1;
raw_frac = (raw_frac + (hidden_one >> (shift + 1))) >> shift; uint32_t shift = exp_shift + (23 - frac_size);
raw_frac = (raw_frac + (hidden_one >> (frac_size - exp_shift + 1))) >> shift;
} }
else if (exp > 15) else if (exp > 15 || (exp == 15 && raw_frac > 0x7fef00 /* numpy overflow value */))
{ {
biased_exp = 0x1F; biased_exp = 0x1F;
raw_frac = 0; raw_frac = 0;
......
...@@ -794,3 +794,27 @@ PartialShape ngraph::infer_slice_shape(const Node* node, ...@@ -794,3 +794,27 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
return dim; return dim;
} }
// Validates `axis` against the accepted range [-tensor_rank, tensor_rank - 1]
// and maps negative (Python-style) axis values onto their non-negative
// equivalent.
//
// \param node        Node used only for the error message context
// \param axis        Possibly-negative axis index
// \param tensor_rank Rank of the tensor the axis refers to
// \return The normalized axis in [0, tensor_rank - 1]
// \throws CheckFailure (via NGRAPH_CHECK) when axis is out of range
std::size_t ngraph::normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank)
{
    const auto axis_range_min = -tensor_rank;
    const auto axis_range_max = tensor_rank - 1;

    // Accepted range of value for axis is [axis_range_min, axis_range_max].
    // NOTE: axis_range_min is already negative, so the message must not add
    // a literal '-' before it (the old "[-" prefix printed "[--4, 3]").
    NGRAPH_CHECK(((axis >= axis_range_min) && (axis <= axis_range_max)),
                 node->description(),
                 "Parameter axis ",
                 axis,
                 " out of the tensor rank [",
                 axis_range_min,
                 ", ",
                 axis_range_max,
                 "].");

    if (axis < 0)
    {
        // Wrap a negative axis to count from the front.
        axis = axis + tensor_rank;
    }

    return static_cast<size_t>(axis);
}
...@@ -102,4 +102,6 @@ namespace ngraph ...@@ -102,4 +102,6 @@ namespace ngraph
const AxisSet& new_axis_mask, const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask, const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask); const AxisSet& ellipsis_mask);
std::size_t normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank);
} }
...@@ -46,6 +46,7 @@ set(SRC ...@@ -46,6 +46,7 @@ set(SRC
aligned_buffer.cpp aligned_buffer.cpp
all_close_f.cpp all_close_f.cpp
assertion.cpp assertion.cpp
attributes.cpp
bfloat16.cpp bfloat16.cpp
build_graph.cpp build_graph.cpp
builder_autobroadcast.cpp builder_autobroadcast.cpp
...@@ -137,6 +138,7 @@ set(SRC ...@@ -137,6 +138,7 @@ set(SRC
type_prop/hard_sigmoid.cpp type_prop/hard_sigmoid.cpp
type_prop/index_reduction.cpp type_prop/index_reduction.cpp
type_prop/layer_norm.cpp type_prop/layer_norm.cpp
type_prop/log_softmax.cpp
type_prop/lrn.cpp type_prop/lrn.cpp
type_prop/lstm_cell.cpp type_prop/lstm_cell.cpp
type_prop/lstm_sequence.cpp type_prop/lstm_sequence.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
// Test-only enum used to exercise enum-valued attribute visitation.
enum class TuringModel
{
    XL400,
    XL1200
};

namespace ngraph
{
    // Name table so TuringModel values can be converted to/from strings by
    // the enum attribute machinery.
    template <>
    EnumNames<TuringModel>& EnumNames<TuringModel>::get()
    {
        static auto enum_names = EnumNames<TuringModel>(
            "TuringModel", {{"XL400", TuringModel::XL400}, {"XL1200", TuringModel::XL1200}});
        return enum_names;
    }

    // Adapter that lets AttributeVisitor handle TuringModel attributes
    // through the generic enum adapter base.
    template <>
    class AttributeAdapter<TuringModel> : public EnumAttributeAdapterBase<TuringModel>
    {
    public:
        AttributeAdapter(TuringModel& value)
            : EnumAttributeAdapterBase<TuringModel>(value)
        {
        }

        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<TuringModel>", 0};
        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
    };

    constexpr DiscreteTypeInfo AttributeAdapter<TuringModel>::type_info;
}
// Given a Turing machine program and data, return scalar 1 if the program would
// complete, 1 if it would not.
//
// Test-only op carrying one attribute of each interesting kind (enum,
// unsigned, small int, string, bool, unsigned vector, signed vector) so the
// AttributeVisitor save/restore path can be exercised end to end.
class Oracle : public op::Op
{
public:
    Oracle(const Output<Node>& program,
           const Output<Node>& data,
           TuringModel turing_model,
           uint64_t model_version,
           uint8_t rev,
           const string& serial_number,
           bool enable_turbo,
           const std::vector<uint64_t>& hyper_parameters,
           const std::vector<int64_t>& ultra_parameters)
        : Op({program, data})
        , m_turing_model(turing_model)
        , m_model_version(model_version)
        , m_rev(rev)
        , m_serial_number(serial_number)
        , m_enable_turbo(enable_turbo)
        , m_hyper_parameters(hyper_parameters)
        , m_ultra_parameters(ultra_parameters)
    {
    }

    static constexpr NodeTypeInfo type_info{"Oracle", 0};
    const NodeTypeInfo& get_type_info() const override { return type_info; }
    // Default ctor required so the factory registry can construct an empty
    // node before replaying attributes into it.
    Oracle() = default;

    TuringModel get_turing_model() const { return m_turing_model; }
    uint64_t get_model_version() const { return m_model_version; }
    const string& get_serial_number() const { return m_serial_number; }
    bool get_enable_turbo() const { return m_enable_turbo; }
    const vector<uint64_t>& get_hyper_parameters() const { return m_hyper_parameters; }
    const vector<int64_t>& get_ultra_parameters() const { return m_ultra_parameters; }

    shared_ptr<Node> copy_with_new_args(const NodeVector& args) const override
    {
        return make_shared<Oracle>(args[0],
                                   args[1],
                                   m_turing_model,
                                   m_model_version,
                                   m_rev,
                                   m_serial_number,
                                   m_enable_turbo,
                                   m_hyper_parameters,
                                   m_ultra_parameters);
    }

    void validate_and_infer_types() override { set_output_type(0, element::i64, {}); }
    bool visit_attributes(AttributeVisitor& visitor) override
    {
        visitor.on_attribute("turing_model", m_turing_model);
        visitor.on_attribute("model_version", m_model_version);
        visitor.on_attribute("rev", m_rev);
        visitor.on_attribute("serial_number", m_serial_number);
        visitor.on_attribute("enable_turbo", m_enable_turbo);
        visitor.on_attribute("hyper_parameters", m_hyper_parameters);
        visitor.on_attribute("ultra_parameters", m_ultra_parameters);
        return true;
    }

protected:
    TuringModel m_turing_model;
    uint64_t m_model_version;
    // NOTE(review): the constructor parameter is uint8_t but this member is
    // int8_t — values >= 128 would change sign. Harmless for this test's
    // value (4), but confirm which type is intended.
    int8_t m_rev;
    string m_serial_number;
    bool m_enable_turbo;
    vector<uint64_t> m_hyper_parameters;
    vector<int64_t> m_ultra_parameters;
};

constexpr NodeTypeInfo Oracle::type_info;
// AttributeVisitor that captures a node's attributes into typed maps, keyed
// by attribute name, so they can later be replayed into a fresh node of the
// same type (see NodeBuilder).
class NodeSaver : public AttributeVisitor
{
public:
    // Immediately walks `node`'s attributes, recording each value.
    NodeSaver(shared_ptr<Node> node)
        : m_node_type_info(node->get_type_info())
    {
        node->visit_attributes(*this);
    }
    const NodeTypeInfo& get_node_type_info() { return m_node_type_info; }

    // Typed accessors; map::at() throws if the named attribute was never
    // recorded during the visit.
    string& get_string(const string& name) { return m_strings.at(name); }
    bool get_bool(const string& name) { return m_bools.at(name); }
    double get_double(const string& name) { return m_doubles.at(name); }
    int64_t get_signed(const string& name) { return m_signeds.at(name); }
    uint64_t get_unsigned(const string& name) { return m_unsigneds.at(name); }
    vector<int64_t>& get_signed_vector(const string& name) { return m_signed_vectors.at(name); }
    void set_string(const string& name, const string& value) { m_strings[name] = value; }
    void set_bool(const string& name, bool value) { m_bools[name] = value; }
    void set_double(const string& name, double value) { m_doubles[name] = value; }
    void set_signed(const string& name, int64_t value) { m_signeds[name] = value; }
    void set_unsigned(const string& name, uint64_t value) { m_unsigneds[name] = value; }
    void set_signed_vector(const string& name, const vector<int64_t>& value)
    {
        m_signed_vectors[name] = value;
    }

    void on_attribute(const string& name, string& value) override { set_string(name, value); };
    void on_attribute(const string& name, bool& value) override { set_bool(name, value); }
    // Any adapter type without a dedicated overload below lands here.
    void on_adapter(const string& name, ValueAccessor<void>& adapter) override
    {
        NGRAPH_CHECK(false, "name cannot be marshalled");
    }
    // The remaining adapter methods fall back on the void adapter if not implemented
    void on_adapter(const string& name, ValueAccessor<string>& adapter) override
    {
        set_string(name, adapter.get());
    };
    void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
    {
        set_signed_vector(name, adapter.get());
    }
    void on_adapter(const string& name, ValueAccessor<int64_t>& adapter) override
    {
        set_signed(name, adapter.get());
    }
    void on_adapter(const string& name, ValueAccessor<double>& adapter) override
    {
        set_double(name, adapter.get());
    }
    // NOTE(review): no on_adapter overload stores into m_unsigneds; unsigned
    // attributes presumably arrive via the int64_t adapter — confirm.

protected:
    NodeTypeInfo m_node_type_info;
    map<string, string> m_strings;
    map<string, bool> m_bools;
    map<string, double> m_doubles;
    map<string, int64_t> m_signeds;
    map<string, uint64_t> m_unsigneds;
    map<string, vector<int64_t>> m_signed_vectors;
};
// AttributeVisitor that constructs a fresh node of the saved type (via the
// factory registry) and replays the values captured by a NodeSaver into it.
class NodeBuilder : public AttributeVisitor
{
public:
    // Captures `node`'s attributes into m_values on construction.
    NodeBuilder(const shared_ptr<Node>& node)
        : m_values(node)
    {
    }

    // Builds a new node of the saved type, replays the recorded attributes
    // into it, and re-runs shape/type inference on the result.
    shared_ptr<Node> create()
    {
        shared_ptr<Node> node(FactoryRegistry<Node>::get().create(m_values.get_node_type_info()));
        node->visit_attributes(*this);
        node->validate_and_infer_types();
        return node;
    }

    void on_attribute(const string& name, string& value) override
    {
        value = m_values.get_string(name);
    };
    void on_attribute(const string& name, bool& value) override { value = m_values.get_bool(name); }
    // Any adapter type without a dedicated overload below lands here.
    void on_adapter(const string& name, ValueAccessor<void>& adapter) override
    {
        NGRAPH_CHECK(false, "name cannot be marshalled");
    }
    // The remaining adapter methods fall back on the void adapter if not implemented
    void on_adapter(const string& name, ValueAccessor<string>& adapter) override
    {
        adapter.set(m_values.get_string(name));
    };
    void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
    {
        adapter.set(m_values.get_signed_vector(name));
    }
    void on_adapter(const string& name, ValueAccessor<int64_t>& adapter) override
    {
        adapter.set(m_values.get_signed(name));
    }
    void on_adapter(const string& name, ValueAccessor<double>& adapter) override
    {
        adapter.set(m_values.get_double(name));
    }

protected:
    NodeSaver m_values;
};
// Round-trips a user-defined op through NodeBuilder (save via NodeSaver,
// rebuild via the factory registry) and checks every attribute survives.
TEST(attributes, user_op)
{
    FactoryRegistry<Node>::get().register_factory<Oracle>();

    auto program_arg = make_shared<op::Parameter>(element::i32, Shape{200});
    auto data_arg = make_shared<op::Parameter>(element::i32, Shape{200});
    auto original = make_shared<Oracle>(program_arg,
                                        data_arg,
                                        TuringModel::XL1200,
                                        2,
                                        4,
                                        "12AU7",
                                        true,
                                        vector<uint64_t>{1, 2, 4, 8},
                                        vector<int64_t>{-1, -2, -4, -8});

    NodeBuilder builder(original);
    auto restored = as_type_ptr<Oracle>(builder.create());

    EXPECT_EQ(restored->get_turing_model(), original->get_turing_model());
    EXPECT_EQ(restored->get_model_version(), original->get_model_version());
    EXPECT_EQ(restored->get_serial_number(), original->get_serial_number());
    EXPECT_EQ(restored->get_enable_turbo(), original->get_enable_turbo());
    EXPECT_EQ(restored->get_hyper_parameters(), original->get_hyper_parameters());
    EXPECT_EQ(restored->get_ultra_parameters(), original->get_ultra_parameters());
}
...@@ -90,8 +90,20 @@ TEST(float16, assigns) ...@@ -90,8 +90,20 @@ TEST(float16, assigns)
TEST(float16, values) TEST(float16, values)
{ {
std::vector<double> f32vec{2.73786e-05, 3.87722e-05, -0.0223043}; std::vector<double> f32vec{2.73786e-05,
std::vector<uint16_t> intvals = {459, 650, 42422}; 3.87722e-05,
-0.0223043,
5.10779e-05,
-5.10779e-05,
-2.553895e-05,
-0.0001021558,
5.960464477539063e-08,
8.940696716308594e-08,
65536.0,
65519.0,
65520.0};
std::vector<uint16_t> intvals = {
459, 650, 42422, 857, 0x8359, 0x81ac, 0x86b2, 0x01, 0x02, 0x7c00, 0x7bff, 0x7c00};
for (size_t i = 0; i < f32vec.size(); ++i) for (size_t i = 0; i < f32vec.size(); ++i)
{ {
float16 fp16val = f32vec.at(i); float16 fp16val = f32vec.at(i);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
// Negative type-prop test: LogSoftmax over a rank-2 input must reject an
// axis of 2 (out of range) at construction time with an ngraph_error whose
// message names the offending parameter.
TEST(type_prop, log_softmax)
{
    const auto input = make_shared<op::Parameter>(element::f64, Shape{2, 2});
    const auto out_of_range_axis = 2;
    try
    {
        const auto node = make_shared<op::LogSoftmax>(input, out_of_range_axis);
        // Reaching this point means validation let the bad axis through.
        FAIL() << "Invalid axis value not detected";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis "));
    }
    catch (...)
    {
        FAIL() << "Log softmax failed for unexpected reason";
    }
}
...@@ -34,7 +34,7 @@ TEST(type_prop, squared_difference) ...@@ -34,7 +34,7 @@ TEST(type_prop, squared_difference)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("axes are incompatible")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
} }
const auto clamp = make_shared<op::SquaredDifference>(x1, x3); const auto clamp = make_shared<op::SquaredDifference>(x1, x3);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment