Commit 305a9a8a authored by Jai Menon's avatar Jai Menon Committed by GitHub

Merge branch 'master' into jmenon/ninja

parents c5787cc5 f1608316
find_program(GRAPHVIZ_EXECUTABLE dot)
# Handle REQUIRED and QUIET arguments
# this will also set GRAPHVIZ_FOUND to true if GRAPHVIZ_EXECUTABLE exists
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Graphviz
"Failed to locate graphviz executable"
GRAPHVIZ_EXECUTABLE)
...@@ -15,6 +15,11 @@ set (SRC ...@@ -15,6 +15,11 @@ set (SRC
tree.cpp tree.cpp
util.cpp util.cpp
log.cpp log.cpp
ngraph/descriptor/input.cpp
ngraph/descriptor/output.cpp
ngraph/descriptor/tensor.cpp
ngraph/descriptor/tensor_view.cpp
ops/binary_elementwise_builtin.cpp
ops/broadcast.cpp ops/broadcast.cpp
ops/concatenate.cpp ops/concatenate.cpp
ops/convert.cpp ops/convert.cpp
...@@ -24,6 +29,7 @@ set (SRC ...@@ -24,6 +29,7 @@ set (SRC
ops/op.cpp ops/op.cpp
ops/parameter.cpp ops/parameter.cpp
ops/tuple.cpp ops/tuple.cpp
ops/unary_elementwise_builtin.cpp
types/element_type.cpp types/element_type.cpp
types/type.cpp types/type.cpp
ngraph/node.cpp ngraph/node.cpp
...@@ -36,15 +42,22 @@ set(NGRAPH_INCLUDE_PATH ...@@ -36,15 +42,22 @@ set(NGRAPH_INCLUDE_PATH
${CMAKE_CURRENT_SOURCE_DIR}/ngraph ${CMAKE_CURRENT_SOURCE_DIR}/ngraph
) )
# find_program (GRAPHVIZ dot)
# message (STATUS "graphviz '${GRAPHVIZ}'")
find_package(Graphviz)
if (GRAPHVIZ_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DGRAPHVIZ_FOUND")
endif()
include_directories("${NGRAPH_INCLUDE_PATH}") include_directories("${NGRAPH_INCLUDE_PATH}")
add_library(ngraph SHARED ${SRC}) add_library(ngraph SHARED ${SRC})
target_include_directories(ngraph PUBLIC "${NGRAPH_INCLUDE_PATH}") target_include_directories(ngraph PUBLIC "${NGRAPH_INCLUDE_PATH}")
if ( APPLE ) if (APPLE)
set_property( TARGET ngraph PROPERTY PREFIX "lib" ) set_property(TARGET ngraph PROPERTY PREFIX "lib")
set_property( TARGET ngraph PROPERTY OUTPUT_NAME "ngraph.so" ) set_property(TARGET ngraph PROPERTY OUTPUT_NAME "ngraph.so")
set_property( TARGET ngraph PROPERTY SUFFIX "" ) set_property(TARGET ngraph PROPERTY SUFFIX "")
endif() endif()
#----------------------------------------------------------------------------------------------- #-----------------------------------------------------------------------------------------------
...@@ -63,7 +76,7 @@ set(CMAKE_INSTALL_PREFIX "$ENV{HOME}/ngraph_dist" CACHE PATH "Install directory" ...@@ -63,7 +76,7 @@ set(CMAKE_INSTALL_PREFIX "$ENV{HOME}/ngraph_dist" CACHE PATH "Install directory"
message (STATUS "Installation directory: ${CMAKE_INSTALL_PREFIX}") message (STATUS "Installation directory: ${CMAKE_INSTALL_PREFIX}")
message (STATUS "To Override use: cmake -DCMAKE_INSTALL_PREFIX=/foo -P cmake_install.cmake") message (STATUS "To Override use: cmake -DCMAKE_INSTALL_PREFIX=/foo -P cmake_install.cmake")
install(TARGETS ngraph DESTINATION ${CMAKE_INSTALL_PREFIX} ) install(TARGETS ngraph DESTINATION ${CMAKE_INSTALL_PREFIX})
install(DIRECTORY install(DIRECTORY
${CMAKE_CURRENT_SOURCE_DIR}/ngraph ${CMAKE_CURRENT_SOURCE_DIR}/ngraph
DESTINATION ${CMAKE_INSTALL_PREFIX} DESTINATION ${CMAKE_INSTALL_PREFIX}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace descriptor;
Input::Input(
Node* node, size_t index, size_t argno, size_t arg_index, const shared_ptr<Output>& output)
: m_node(node)
, m_index(index)
, m_argno(argno)
, m_arg_index(arg_index)
, m_output(output)
{
output->add_input(this);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
namespace ngraph
{
namespace descriptor
{
class Output;
// Describes a tensor that is an input to an op, directly or indirectly via a tuple
class Input : public std::enable_shared_from_this<Input>
{
Input(const Input&) = delete;
Input& operator=(const Input&) = delete;
public:
/// @param node The node that owns this input; not shared to prevent owner loop
/// @param index The position of this this tensor in all input tensors
/// @param argno The position of the argument with this tensor
/// @param arg_index The position of the tensor within the argument's tensors
/// @param output The output that supplies a value for this input
Input(Node* node,
size_t index,
size_t argno,
size_t arg_index,
const std::shared_ptr<Output>& output);
std::shared_ptr<Node> get_node() { return m_node->shared_from_this(); }
size_t get_argno() const { return m_argno; }
size_t get_arg_index() const { return m_arg_index; }
size_t get_index() const { return m_index; }
std::shared_ptr<Output> get_output() const { return m_output; }
protected:
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
std::shared_ptr<Output> m_output;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace descriptor;
Output::Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view)
: m_node(node)
, m_index(index)
, m_tensor_view(tensor_view)
{
}
// Add an input to the vector of inputs that use this output.
void Output::add_input(Input* input)
{
m_inputs.insert(input);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <set>
#include "descriptor/tensor_view.hpp"
namespace ngraph
{
namespace descriptor
{
// Describes an output tensor of an op
class Output : public std::enable_shared_from_this<Output>
{
Output(const Output&) = delete;
Output& operator=(const Output&) = delete;
public:
/// @param node Node that owns this output. Not shared to prevent owner loop.
/// @param index Position of the output tensor in all output tensors
/// @param tensor_view The view of this tensor; where the value will be written
Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view);
std::shared_ptr<Node> get_node() const { return m_node->shared_from_this(); }
size_t get_index() const { return m_index; }
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void add_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; }
protected:
Node* m_node;
size_t m_index;
std::shared_ptr<TensorView> m_tensor_view;
std::set<Input*> m_inputs;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "descriptor/tensor.hpp"
using namespace ngraph;
using namespace descriptor;
Tensor::Tensor(const element::Type& element_type, PrimaryTensorView* primary_tensor_view)
: m_element_type(element_type)
, m_primary_tensor_view(primary_tensor_view)
{
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
namespace ngraph
{
namespace element
{
class Type;
}
namespace descriptor
{
class TensorView;
class PrimaryTensorView;
class Tensor
{
friend class PrimaryTensorView;
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
Tensor(const element::Type& element_type, PrimaryTensorView* tensor_view);
protected:
const element::Type& m_element_type;
PrimaryTensorView* m_primary_tensor_view;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "descriptor/tensor_view.hpp"
using namespace ngraph;
using namespace descriptor;
const Tensor& PrimaryTensorView::get_tensor() const
{
return m_tensor;
}
Tensor& PrimaryTensorView::get_tensor()
{
return m_tensor;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "descriptor/tensor.hpp"
#include "shape.hpp"
#include "type.hpp"
namespace ngraph
{
namespace descriptor
{
class Tensor;
class TensorViewLayout;
// Describes a view of an instantiated tensor
class TensorView : public std::enable_shared_from_this<TensorView>
{
TensorView(const TensorView&) = delete;
TensorView& operator=(const TensorView&) = delete;
protected:
TensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type)
: m_tensor_view_type(tensor_view_type)
{
}
public:
virtual ~TensorView() {}
virtual const Tensor& get_tensor() const = 0;
virtual Tensor& get_tensor() = 0;
std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{
return m_tensor_view_type;
}
const std::shared_ptr<TensorViewLayout>& get_tensor_view_layout() const
{
return m_tensor_view_layout;
}
void set_tensor_view_layout(const std::shared_ptr<TensorViewLayout>& tensor_view_layout)
{
m_tensor_view_layout = tensor_view_layout;
}
protected:
std::shared_ptr<const TensorViewType> m_tensor_view_type;
std::shared_ptr<TensorViewLayout> m_tensor_view_layout;
};
// A PrimaryTensorView owns the tensor. All other views are the result
// of some index operation on the primary view.
class PrimaryTensorView : public TensorView
{
public:
PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type)
: TensorView(tensor_view_type)
, m_tensor(tensor_view_type->get_element_type(), this)
{
}
virtual const Tensor& get_tensor() const override;
virtual Tensor& get_tensor() override;
protected:
Tensor m_tensor;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
namespace ngraph
{
namespace descriptor
{
using Strides = std::vector<size_t>;
class TensorViewLayout
{
protected:
Strides m_strides;
};
}
}
...@@ -23,13 +23,24 @@ ...@@ -23,13 +23,24 @@
#include "except.hpp" #include "except.hpp"
#include "function.hpp" #include "function.hpp"
#include "node.hpp" #include "node.hpp"
#include "descriptor/input.hpp"
#include "descriptor/output.hpp"
#include "descriptor/tensor_view.hpp"
#include "descriptor/tensor_view_layout.hpp"
#include "descriptor/tensor.hpp"
#include "op.hpp" #include "op.hpp"
#include "ops/add.hpp"
#include "ops/broadcast.hpp" #include "ops/broadcast.hpp"
#include "ops/ceiling.hpp"
#include "ops/concatenate.hpp" #include "ops/concatenate.hpp"
#include "ops/constant.hpp" #include "ops/constant.hpp"
#include "ops/convert.hpp" #include "ops/convert.hpp"
#include "ops/divide.hpp"
#include "ops/dot.hpp" #include "ops/dot.hpp"
#include "ops/floor.hpp"
#include "ops/multiply.hpp"
#include "ops/parameter.hpp" #include "ops/parameter.hpp"
#include "ops/subtract.hpp"
#include "ops/tuple.hpp" #include "ops/tuple.hpp"
#include "shape.hpp" #include "shape.hpp"
#include "type.hpp" #include "type.hpp"
...@@ -12,13 +12,14 @@ ...@@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "node.hpp" #include "ngraph.hpp"
#include "op.hpp"
size_t ngraph::Node::m_next_instance_id = 0; using namespace std;
using namespace ngraph;
ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, size_t Node::m_next_instance_id = 0;
std::shared_ptr<ValueType> value_type)
Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType> value_type)
: m_arguments(arguments) : m_arguments(arguments)
, m_value_type(value_type) , m_value_type(value_type)
, m_instance_id(m_next_instance_id++) , m_instance_id(m_next_instance_id++)
...@@ -30,33 +31,63 @@ ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, ...@@ -30,33 +31,63 @@ ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
} }
} }
void ngraph::Node::set_value_type_checked(const std::shared_ptr<ValueType>& value_type) void Node::set_value_type_checked(const shared_ptr<ValueType>& value_type)
{ {
if (nullptr == m_value_type){ if (nullptr == m_value_type)
{
m_value_type = value_type; m_value_type = value_type;
} else { }
if (*m_value_type != *value_type){ else
throw ngraph::ngraph_error("Setting value type to a different ValueType"); {
if (*m_value_type != *value_type)
{
throw ngraph_error("Setting value type to a different ValueType");
}
}
}
void Node::assign_tensors()
{
vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
get_value_type()->collect_tensor_views(tensor_view_types);
size_t i = 0;
for (auto tvt : tensor_view_types)
{
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt);
auto output = make_shared<descriptor::Output>(this, i++, tensor_view_descriptor);
m_outputs.push_back(output);
}
i = 0;
size_t argno = 0;
for (auto arg : get_arguments())
{
size_t arg_index = 0;
for (auto output : arg->get_outputs())
{
auto input = make_shared<descriptor::Input>(this, i++, argno, arg_index++, output);
m_inputs.push_back(input);
} }
argno++;
} }
} }
bool ngraph::Node::is_op() const bool Node::is_op() const
{ {
return dynamic_cast<const ngraph::Op*>(this) != nullptr; return dynamic_cast<const Op*>(this) != nullptr;
} }
bool ngraph::Node::is_parameter() const bool Node::is_parameter() const
{ {
return dynamic_cast<const ngraph::op::Parameter*>(this) != nullptr; return dynamic_cast<const op::Parameter*>(this) != nullptr;
} }
namespace ngraph namespace ngraph
{ {
std::ostream& operator<<(std::ostream& out, const ngraph::Node& node) ostream& operator<<(ostream& out, const Node& node)
{ {
auto op_tmp = dynamic_cast<const ngraph::Op*>(&node); auto op_tmp = dynamic_cast<const Op*>(&node);
auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node); auto parameter_tmp = dynamic_cast<const Op*>(&node);
if (op_tmp) if (op_tmp)
{ {
out << "Op(" << op_tmp->get_node_id() << ")"; out << "Op(" << op_tmp->get_node_id() << ")";
......
...@@ -27,6 +27,12 @@ namespace ngraph ...@@ -27,6 +27,12 @@ namespace ngraph
{ {
class Op; class Op;
namespace descriptor
{
class Input;
class Output;
}
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor /// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
...@@ -53,6 +59,10 @@ namespace ngraph ...@@ -53,6 +59,10 @@ namespace ngraph
/// Propagate types and check arguments for consistency /// Propagate types and check arguments for consistency
virtual void propagate_types() = 0; virtual void propagate_types() = 0;
/// Assign Input and Output vectors
// This might later need to be virtual.
void assign_tensors();
const Nodes& get_arguments() const { return m_arguments; } const Nodes& get_arguments() const { return m_arguments; }
const std::multiset<Node*>& users() const { return m_users; } const std::multiset<Node*>& users() const { return m_users; }
...@@ -94,7 +104,9 @@ namespace ngraph ...@@ -94,7 +104,9 @@ namespace ngraph
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
std::vector<std::shared_ptr<descriptor::Input>> get_inputs() { return m_inputs; }
std::vector<std::shared_ptr<descriptor::Output>> get_outputs() {return m_outputs; }
protected: protected:
Nodes m_arguments; Nodes m_arguments;
...@@ -103,5 +115,7 @@ namespace ngraph ...@@ -103,5 +115,7 @@ namespace ngraph
std::string m_name; std::string m_name;
size_t m_instance_id; size_t m_instance_id;
static size_t m_next_instance_id; static size_t m_next_instance_id;
std::vector<std::shared_ptr<descriptor::Input>> m_inputs;
std::vector<std::shared_ptr<descriptor::Output>> m_outputs;
}; };
} }
...@@ -61,10 +61,6 @@ namespace ngraph ...@@ -61,10 +61,6 @@ namespace ngraph
{ {
public: public:
virtual std::string description() const override { return "Builtin"; } virtual std::string description() const override { return "Builtin"; }
/// Name of the builtin op, for debugging and logging.
// TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {}
protected: protected:
Builtin(const std::vector<std::shared_ptr<Node>>& args) Builtin(const std::vector<std::shared_ptr<Node>>& args)
...@@ -73,58 +69,60 @@ namespace ngraph ...@@ -73,58 +69,60 @@ namespace ngraph
} }
}; };
class Abs : public Builtin /// Index ops create a new way to index the same tensor elements
class IndexBuiltin : public Builtin
{ {
public: protected:
Abs(const std::shared_ptr<Node>& arg0) IndexBuiltin(const std::shared_ptr<Node>& arg)
: Builtin({arg0}) : Builtin(Nodes{arg})
{ {
} }
virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override;
}; };
class Add : public Builtin /// Operations where the same element function is applied to each element
/// Op(X)[I] = op(X[I])
class UnaryElementwiseBuiltin : public Builtin
{ {
public: protected:
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) UnaryElementwiseBuiltin(const std::shared_ptr<Node>& arg)
: Builtin({arg0, arg1}) : Builtin(Nodes{arg})
{ {
} }
virtual std::string get_op_class_name() const override { return "Add"; }
//virtual void propagate_types() override; public:
virtual void propagate_types() override;
}; };
class Ceiling : public Builtin /// Op(X, Y)[I] = op(X[I], Y[I])
class BinaryElementwiseBuiltin : public Builtin
{ {
public: protected:
Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) BinaryElementwiseBuiltin(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : Builtin(Nodes{arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "Ceiling"; } public:
//virtual void propagate_types() override; virtual void propagate_types() override;
}; };
class Divide : public Builtin class Abs : public UnaryElementwiseBuiltin
{ {
public: public:
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Abs(const std::shared_ptr<Node>& arg0)
: Builtin({arg0, arg1}) : UnaryElementwiseBuiltin({arg0})
{ {
} }
virtual std::string get_op_class_name() const override { return "Divide"; } virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Equal : public Builtin class Equal : public BinaryElementwiseBuiltin
{ {
public: public:
Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -132,11 +130,11 @@ namespace ngraph ...@@ -132,11 +130,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Exp : public Builtin class Exp : public UnaryElementwiseBuiltin
{ {
public: public:
Exp(const std::shared_ptr<Node>& arg0) Exp(const std::shared_ptr<Node>& arg0)
: Builtin({arg0}) : UnaryElementwiseBuiltin(arg0)
{ {
} }
...@@ -144,23 +142,11 @@ namespace ngraph ...@@ -144,23 +142,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Floor : public Builtin class Greater : public BinaryElementwiseBuiltin
{
public:
Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Floor"; }
//virtual void propagate_types() override;
};
class Greater : public Builtin
{ {
public: public:
Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -168,11 +154,11 @@ namespace ngraph ...@@ -168,11 +154,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Less : public Builtin class Less : public BinaryElementwiseBuiltin
{ {
public: public:
Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -180,11 +166,11 @@ namespace ngraph ...@@ -180,11 +166,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Log : public Builtin class Log : public UnaryElementwiseBuiltin
{ {
public: public:
Log(const std::shared_ptr<Node>& arg0) Log(const std::shared_ptr<Node>& arg0)
: Builtin({arg0}) : UnaryElementwiseBuiltin(arg0)
{ {
} }
...@@ -192,11 +178,11 @@ namespace ngraph ...@@ -192,11 +178,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Maximum : public Builtin class Maximum : public BinaryElementwiseBuiltin
{ {
public: public:
Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -204,11 +190,11 @@ namespace ngraph ...@@ -204,11 +190,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Minimum : public Builtin class Minimum : public BinaryElementwiseBuiltin
{ {
public: public:
Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -216,23 +202,11 @@ namespace ngraph ...@@ -216,23 +202,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Multiply : public Builtin class Negative : public UnaryElementwiseBuiltin
{
public:
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Multiply"; }
//virtual void propagate_types() override;
};
class Negative : public Builtin
{ {
public: public:
Negative(const std::shared_ptr<Node>& arg0) Negative(const std::shared_ptr<Node>& arg0)
: Builtin({arg0}) : UnaryElementwiseBuiltin(arg0)
{ {
} }
...@@ -240,11 +214,11 @@ namespace ngraph ...@@ -240,11 +214,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Power : public Builtin class Power : public BinaryElementwiseBuiltin
{ {
public: public:
Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -252,11 +226,11 @@ namespace ngraph ...@@ -252,11 +226,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Remainder : public Builtin class Remainder : public BinaryElementwiseBuiltin
{ {
public: public:
Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -264,11 +238,11 @@ namespace ngraph ...@@ -264,11 +238,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Reshape : public Builtin class Reshape : public IndexBuiltin
{ {
public: public:
Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape) Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
: Builtin({arg0}) : IndexBuiltin(arg0)
, m_shape(shape) , m_shape(shape)
{ {
} }
...@@ -278,17 +252,5 @@ namespace ngraph ...@@ -278,17 +252,5 @@ namespace ngraph
protected: protected:
Shape m_shape; Shape m_shape;
}; };
class Subtract : public Builtin
{
public:
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Subtract"; }
//virtual void propagate_types() override;
};
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Add : public BinaryElementwiseBuiltin
{
public:
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Add"; }
};
}
}
...@@ -18,7 +18,7 @@ namespace ngraph ...@@ -18,7 +18,7 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Broadcast : public Builtin class Broadcast : public IndexBuiltin
{ {
public: public:
/// ///
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
Broadcast(const std::shared_ptr<Node>& arg, Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape, const Shape& shape,
const AxisSet& broadcast_axes) const AxisSet& broadcast_axes)
: Builtin({arg}) : IndexBuiltin(arg)
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Ceiling : public BinaryElementwiseBuiltin
{
public:
Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Ceiling"; }
};
}
}
...@@ -18,11 +18,11 @@ namespace ngraph ...@@ -18,11 +18,11 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Convert : public Builtin class Convert : public UnaryElementwiseBuiltin
{ {
public: public:
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type) Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: Builtin({arg}) : UnaryElementwiseBuiltin({arg})
, m_element_type(element_type) , m_element_type(element_type)
{ {
} }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Divide : public BinaryElementwiseBuiltin
{
public:
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Divide"; }
};
}
}
...@@ -21,7 +21,25 @@ namespace ngraph ...@@ -21,7 +21,25 @@ namespace ngraph
class Dot : public Builtin class Dot : public Builtin
{ {
public: public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// Computes the dot product of two tensors.
///
/// There are three possible cases:
/// (1) arg0 or arg1 is 0-dimensional. Then, we treat the 0-dimensional
/// argument(s) as scalars and compute a scalar-tensor or
/// scalar-scalar product.
/// (Example: arg0 has shape {1,2,3} and arg1 has shape {}; then
/// the result will have shape {1,2,3}.)
///
/// (2) arg1 is 1-dimensional. Then, we compute a dot product reducing
/// on the innermost (rightmost) dimensions of arg0 and arg1.
/// (Example: arg0 has shape {1,2,3} and arg1 has shape {3}; then
/// the result will have shape {1,2}.)
///
/// (3) arg1 is more than 1-dimensional. Then, we compute a dot product
/// reducing on the innermost (rightmost) dimension of arg0, and the
/// next-to-innermost dimension of arg1.
/// (Example: arg0 has shape {3,4} and arg1 has shape {4,3}; then
/// the result will have shape {3,3}.)
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : Builtin({arg0, arg1})
{ {
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Floor : public BinaryElementwiseBuiltin
{
public:
Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Floor"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Multiply : public BinaryElementwiseBuiltin
{
public:
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Multiply"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Subtract : public BinaryElementwiseBuiltin
{
public:
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Subtract"; }
};
}
}
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <unordered_map>
#include <deque>
#include "topological_sort.hpp" #include "topological_sort.hpp"
#include "node.hpp" #include "node.hpp"
#include "util.hpp" #include "util.hpp"
...@@ -19,31 +22,19 @@ ...@@ -19,31 +22,19 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n)
{
for (auto dn = m_dependent_nodes.begin(); dn != m_dependent_nodes.end(); dn++)
{
if (dn->first > 0) // Skip zero as they should never be promoted
{
auto it = find(dn->second.begin(), dn->second.end(), n);
if (it != dn->second.end())
{
// found the node
dn->second.erase(it);
m_dependent_nodes[dn->first - 1].push_back(n);
}
}
}
}
void ngraph::TopologicalSort::process(node_ptr p) void ngraph::TopologicalSort::process(node_ptr p)
{ {
deque<Node*> independent_nodes;
unordered_map<Node*, size_t> node_depencency_count;
traverse_nodes(p, [&](node_ptr node) { traverse_nodes(p, [&](node_ptr node) {
list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()]; node_depencency_count[node.get()] = node->get_arguments().size();
node_list.push_back(node.get()); if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node.get());
}
}); });
list<Node*>& independent_nodes = m_dependent_nodes[0];
while (independent_nodes.size() > 0) while (independent_nodes.size() > 0)
{ {
auto independent_node = independent_nodes.front(); auto independent_node = independent_nodes.front();
...@@ -52,12 +43,22 @@ void ngraph::TopologicalSort::process(node_ptr p) ...@@ -52,12 +43,22 @@ void ngraph::TopologicalSort::process(node_ptr p)
for (auto user : independent_node->users()) for (auto user : independent_node->users())
{ {
promote_node(user); node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
} }
} }
} }
const std::vector<Node*>& ngraph::TopologicalSort::get_sorted_list() const const std::list<Node*>& ngraph::TopologicalSort::get_sorted_list() const
{
return m_sorted_list;
}
std::list<Node*>& ngraph::TopologicalSort::get_sorted_list()
{ {
return m_sorted_list; return m_sorted_list;
} }
...@@ -14,10 +14,8 @@ ...@@ -14,10 +14,8 @@
#pragma once #pragma once
#include <list>
#include <map>
#include <memory> #include <memory>
#include <vector> #include <list>
namespace ngraph namespace ngraph
{ {
...@@ -32,11 +30,11 @@ public: ...@@ -32,11 +30,11 @@ public:
TopologicalSort() {} TopologicalSort() {}
void process(node_ptr); void process(node_ptr);
const std::vector<Node*>& get_sorted_list() const; const std::list<Node*>& get_sorted_list() const;
std::list<Node*>& get_sorted_list();
private: private:
void promote_node(Node* n); void promote_node(Node* n);
std::map<size_t, std::list<Node*>> m_dependent_nodes; std::list<Node*> m_sorted_list;
std::vector<Node*> m_sorted_list;
}; };
...@@ -34,10 +34,13 @@ namespace ngraph ...@@ -34,10 +34,13 @@ namespace ngraph
virtual ~ValueType() {} virtual ~ValueType() {}
virtual bool operator==(const ValueType& that) const = 0; virtual bool operator==(const ValueType& that) const = 0;
bool operator!=(const ValueType& that) const { return !(*this == that); } bool operator!=(const ValueType& that) const { return !(*this == that); }
/// Add tensor views in depth-first order.
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0;
}; };
/// Describes a tensor view; an element type and a shape. /// Describes a tensor view; an element type and a shape.
class TensorViewType : public ValueType class TensorViewType : public ValueType, public std::enable_shared_from_this<TensorViewType>
{ {
public: public:
/// /param element_type The type of the tensor elements. /// /param element_type The type of the tensor elements.
...@@ -52,6 +55,7 @@ namespace ngraph ...@@ -52,6 +55,7 @@ namespace ngraph
const Shape& get_shape() const { return m_shape; } const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType& that) const override; virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
...@@ -78,6 +82,7 @@ namespace ngraph ...@@ -78,6 +82,7 @@ namespace ngraph
std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; } std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; }
virtual bool operator==(const ValueType& that) const override; virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
protected: protected:
std::vector<std::shared_ptr<ValueType>> m_element_types; std::vector<std::shared_ptr<ValueType>> m_element_types;
......
...@@ -34,13 +34,46 @@ void Visualize::add(node_ptr p) ...@@ -34,13 +34,46 @@ void Visualize::add(node_ptr p)
traverse_nodes(p, [&](node_ptr node) { traverse_nodes(p, [&](node_ptr node) {
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n"; m_ss << add_attributes(arg);
m_ss << add_attributes(node);
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id();
m_ss << ";\n";
} }
}); });
} }
std::string Visualize::add_attributes(node_ptr node)
{
string rc;
if (!contains(m_nodes_with_attributes, node))
{
m_nodes_with_attributes.insert(node);
rc = get_attributes(node);
}
return rc;
}
std::string Visualize::get_attributes(node_ptr node)
{
stringstream ss;
if (node->is_parameter())
{
ss << " " << node->get_node_id() << " [shape=box color=blue]\n";
}
else if (node->is_op())
{
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
}
else
{
ss << " " << node->get_node_id() << " [shape=diamond color=red]\n";
}
return ss.str();
}
void Visualize::save_dot(const string& path) const void Visualize::save_dot(const string& path) const
{ {
#if GRAPHVIZ_FOUND
auto tmp_file = path + ".tmp"; auto tmp_file = path + ".tmp";
ofstream out(tmp_file); ofstream out(tmp_file);
if (out) if (out)
...@@ -56,6 +89,8 @@ void Visualize::save_dot(const string& path) const ...@@ -56,6 +89,8 @@ void Visualize::save_dot(const string& path) const
auto stream = popen(cmd.c_str(), "r"); auto stream = popen(cmd.c_str(), "r");
pclose(stream); pclose(stream);
// remove(tmp_file.c_str()); remove(tmp_file.c_str());
} }
#else
#endif
} }
...@@ -36,6 +36,10 @@ public: ...@@ -36,6 +36,10 @@ public:
void save_dot(const std::string& path) const; void save_dot(const std::string& path) const;
private: private:
std::stringstream m_ss; std::string add_attributes(node_ptr node);
std::string m_name; std::string get_attributes(node_ptr node);
std::stringstream m_ss;
std::string m_name;
std::set<node_ptr> m_nodes_with_attributes;
}; };
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph::op;
void BinaryElementwiseBuiltin::propagate_types()
{
if (m_arguments.size() != 2)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments must be tensor views");
}
if (*arg0_tensor_type != *arg1_tensor_type)
{
throw ngraph_error("Arguments must have the same tensor view type");
}
set_value_type_checked(arg0_tensor_type);
}
...@@ -34,12 +34,12 @@ void Dot::propagate_types() ...@@ -34,12 +34,12 @@ void Dot::propagate_types()
throw ngraph_error("Arguments to dot must have the same element type"); throw ngraph_error("Arguments to dot must have the same element type");
} }
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->get_shape(); vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape(); vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1; size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction; size_t arg1_reduction;
const bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0;
if (arg1_shape.size() > 1) if (arg1_shape.size() > 1)
{ {
arg1_reduction = arg1_shape.size() - 2; arg1_reduction = arg1_shape.size() - 2;
...@@ -48,13 +48,30 @@ void Dot::propagate_types() ...@@ -48,13 +48,30 @@ void Dot::propagate_types()
{ {
arg1_reduction = arg1_shape.size() - 1; arg1_reduction = arg1_shape.size() - 1;
} }
if (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction)) if (!is_scalar_mult && (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction)))
{ {
throw ngraph_error("Dot reduction axes not compatible"); throw ngraph_error("Dot reduction axes not compatible");
} }
vector<size_t> result_shape; vector<size_t> result_shape;
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end()); result_shape.reserve(arg0_shape.size() + arg1_shape.size() - (is_scalar_mult ? 0 : 2));
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end()); for(auto i = 0; i < arg0_shape.size(); i++)
set_value_type_checked(make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape)); {
if(is_scalar_mult || i != arg0_reduction)
{
result_shape.push_back(arg0_shape[i]);
}
}
for(auto i = 0; i < arg1_shape.size(); i++)
{
if(is_scalar_mult || i != arg1_reduction)
{
result_shape.push_back(arg1_shape[i]);
}
}
auto result_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
set_value_type_checked(result_type);
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph::op;
void UnaryElementwiseBuiltin::propagate_types()
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
if (nullptr == arg_tensor_type)
{
throw ngraph_error("Argument must be tensor view");
}
set_value_type_checked(arg_tensor_type);
}
...@@ -37,6 +37,11 @@ bool TensorViewType::operator==(const ValueType& that) const ...@@ -37,6 +37,11 @@ bool TensorViewType::operator==(const ValueType& that) const
return true; return true;
} }
void TensorViewType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const
{
views.push_back(shared_from_this());
}
bool TupleType::operator==(const ValueType& that) const bool TupleType::operator==(const ValueType& that) const
{ {
auto that_tvt = dynamic_cast<const TupleType*>(&that); auto that_tvt = dynamic_cast<const TupleType*>(&that);
...@@ -46,3 +51,10 @@ bool TupleType::operator==(const ValueType& that) const ...@@ -46,3 +51,10 @@ bool TupleType::operator==(const ValueType& that) const
} }
return that_tvt->get_element_types() == get_element_types(); return that_tvt->get_element_types() == get_element_types();
} }
void TupleType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const
{
for(auto elt : m_element_types){
elt->collect_tensor_views(views);
}
}
...@@ -29,15 +29,16 @@ link_directories( ...@@ -29,15 +29,16 @@ link_directories(
set (SRC set (SRC
main.cpp main.cpp
build_graph.cpp build_graph.cpp
util.cpp eigen.cpp
tensor.cpp
element_type.cpp element_type.cpp
uuid.cpp mkldnn.cpp
op.cpp
input_output_assign.cpp
tensor.cpp
topological_sort.cpp topological_sort.cpp
type_prop.cpp type_prop.cpp
op.cpp util.cpp
eigen.cpp uuid.cpp
mkldnn.cpp
) )
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
TEST(input_output, param_tensor)
{
// Params have no arguments, so we can check that the value becomes a tensor output
auto tv_tp = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto param = make_shared<op::Parameter>(tv_tp);
param->propagate_types();
param->assign_tensors();
ASSERT_EQ(param->get_outputs().size(), 1);
for (size_t i = 0; i < param->get_outputs().size(); i++)
{
auto output = param->get_outputs()[i];
ASSERT_EQ(i, output->get_index());
ASSERT_EQ(param, output->get_node());
}
ASSERT_EQ(*tv_tp, *param->get_outputs()[0]->get_tensor_view()->get_tensor_view_type());
}
TEST(input_output, param_tuple)
{
// Same as param_tensor, but for a tuple
auto tv_tp_0 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto tv_tp_1 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4, 6});
auto tp_tp = make_shared<TupleType>(std::vector<std::shared_ptr<ValueType>>{tv_tp_0, tv_tp_1});
auto param = make_shared<op::Parameter>(tp_tp);
param->propagate_types();
param->assign_tensors();
ASSERT_EQ(param->get_outputs().size(), 2);
for (size_t i = 0; i < param->get_outputs().size(); i++)
{
auto output = param->get_outputs()[i];
ASSERT_EQ(i, output->get_index());
ASSERT_EQ(param, output->get_node());
}
ASSERT_EQ(*tv_tp_0, *param->get_outputs()[0]->get_tensor_view()->get_tensor_view_type());
ASSERT_EQ(*tv_tp_1, *param->get_outputs()[1]->get_tensor_view()->get_tensor_view_type());
}
TEST(input_output, simple_output)
{
auto tv_tp_0 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto param_0 = make_shared<op::Parameter>(tv_tp_0);
auto param_1 = make_shared<op::Parameter>(tv_tp_0);
auto add = make_shared<op::Add>(param_0, param_1);
// Sort the ops
vector<shared_ptr<Node>> nodes;
nodes.push_back(param_0);
nodes.push_back(param_1);
nodes.push_back(add);
// Type info
for (auto node : nodes)
{
node->propagate_types();
}
// Add inputs/outputs
for (auto node : nodes)
{
node->assign_tensors();
}
// At this point, the add should have each input associated with the output of the appropriate parameter
auto inputs = add->get_inputs();
ASSERT_EQ(2, inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
auto input = inputs[i];
ASSERT_EQ(i, input->get_index());
ASSERT_EQ(i, input->get_argno());
ASSERT_EQ(0, input->get_arg_index());
ASSERT_EQ(input->get_output()->get_node(), add->get_arguments()[i]);
}
}
...@@ -22,11 +22,12 @@ ...@@ -22,11 +22,12 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/topological_sort.hpp" #include "ngraph/topological_sort.hpp"
#include "ngraph/visualize.hpp" #include "ngraph/visualize.hpp"
#include "util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
static bool validate_list(const vector<Node*>& nodes) static bool validate_list(const list<Node*>& nodes)
{ {
bool rc = true; bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++) for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
...@@ -38,7 +39,7 @@ static bool validate_list(const vector<Node*>& nodes) ...@@ -38,7 +39,7 @@ static bool validate_list(const vector<Node*>& nodes)
{ {
dependencies.push_back(n.get()); dependencies.push_back(n.get());
} }
auto tmp = it + 1; auto tmp = it++;
for (; tmp != nodes.rend(); tmp++) for (; tmp != nodes.rend(); tmp++)
{ {
auto dep_tmp = *tmp; auto dep_tmp = *tmp;
...@@ -87,12 +88,18 @@ TEST(topological_sort, basic) ...@@ -87,12 +88,18 @@ TEST(topological_sort, basic)
ASSERT_EQ(2, r0->get_arguments().size()); ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0); auto op_r0 = static_pointer_cast<Op>(r0);
Visualize vz; // Visualize vz;
vz.add(r0); // vz.add(r0);
vz.save_dot("test.png"); // vz.save_dot("test.png");
TopologicalSort ts; TopologicalSort ts;
ts.process(r0); ts.process(r0);
auto sorted_list = ts.get_sorted_list(); auto sorted_list = ts.get_sorted_list();
size_t node_count = 0;
traverse_nodes(r0, [&](node_ptr node) {
node_count++;
});
EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list)); EXPECT_TRUE(validate_list(sorted_list));
} }
...@@ -20,6 +20,14 @@ ...@@ -20,6 +20,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node);
void test_binary_bad_arguments_views(const shared_ptr<Node>& node);
void test_binary_good_arguments(const shared_ptr<Node>& node);
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y));
//
// Tests for broadcast.
//
TEST(type_prop, broadcast_deduce) TEST(type_prop, broadcast_deduce)
{ {
// Deduce type // Deduce type
...@@ -52,6 +60,7 @@ TEST(type_prop, broadcast_deduce_incorrect) ...@@ -52,6 +60,7 @@ TEST(type_prop, broadcast_deduce_incorrect)
try try
{ {
bc->propagate_types(); bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type"; FAIL() << "Deduced type should disagree with specified type";
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
...@@ -72,6 +81,7 @@ TEST(type_prop, broadcast_bad_arguments) ...@@ -72,6 +81,7 @@ TEST(type_prop, broadcast_bad_arguments)
try try
{ {
bc->propagate_types(); bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Tuple argument to broadcast not detected."; FAIL() << "Tuple argument to broadcast not detected.";
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
...@@ -84,3 +94,245 @@ TEST(type_prop, broadcast_bad_arguments) ...@@ -84,3 +94,245 @@ TEST(type_prop, broadcast_bad_arguments)
} }
} }
//
// Tests for dot product.
//
TEST(type_prop, dot_deduce_scalar_2d)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,5});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,5}));
}
TEST(type_prop, dot_deduce_2d_scalar)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,5});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,5}));
}
TEST(type_prop, dot_deduce_scalar_scalar)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{}));
}
TEST(type_prop, dot_deduce_scalar_1d)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{6});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{6}));
}
TEST(type_prop, dot_deduce_1d)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{}));
}
TEST(type_prop, dot_deduce_2d)
{
// Deduce type for 2D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2,3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,3}));
}
TEST(type_prop, dot_deduce_different_d)
{
// Deduce type for different-dimension arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2,8,4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1,2,3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2,8,4,1,3}));
}
TEST(type_prop, dot_deduce_different_d_correct)
{
// Deduced type matches explicitly set type
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2,8,4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1,2,3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->set_value_type(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2,8,4,1,3}));
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2,8,4,1,3}));
}
TEST(type_prop, dot_deduce_element_type_mismatch)
{
// Type deduction fails due to element type mismatch
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,2});
auto param2 = make_shared<op::Parameter>(element::Int32::element_type(), Shape{2,5});
auto bc = make_shared<op::Dot>(param1, param2);
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Element type mismatch not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments to dot must have the same element type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dot_deduce_reduction_axes_size_mismatch)
{
// Type deduction fails due to reduction axes size mismatch
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3,5});
auto bc = make_shared<op::Dot>(param1, param2);
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Dot reduction axes size mismatch not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Dot reduction axes not compatible"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
//
// Tests for binary elementwise ops.
//
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node)
{
try
{
node->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Tuple argument not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must be tensor views"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_bad_arguments_views(const shared_ptr<Node>& node)
{
try
{
node->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_good_arguments(const shared_ptr<Node>& node)
{
node->propagate_types();
EXPECT_EQ(*node->get_value_type(), *node->get_arguments()[0]->get_value_type());
}
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
{
// Check for bad arguments
auto tp0_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tp1_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_4_2_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 2}));
test_binary_bad_arguments_tuple(f(tp0_param, tp1_param));
test_binary_bad_arguments_tuple(f(tp0_param, tv0_2_4_param_0));
test_binary_bad_arguments_tuple(f(tv0_2_4_param_0, tp0_param));
test_binary_bad_arguments_views(f(tv0_2_4_param_0, tv0_4_2_param));
test_binary_good_arguments(f(tv0_2_4_param_0, tv0_2_4_param_1));
}
TEST(type_prop, add_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Add>(x, y);
});
}
TEST(type_prop, ceiling_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Ceiling>(x, y);
});
}
TEST(type_prop, divide_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Divide>(x, y);
});
}
TEST(type_prop, floor_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Floor>(x, y);
});
}
TEST(type_prop, multiply_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Multiply>(x, y);
});
}
TEST(type_prop, subtract_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Subtract>(x, y);
});
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment