Commit 305a9a8a authored by Jai Menon's avatar Jai Menon Committed by GitHub

Merge branch 'master' into jmenon/ninja

parents c5787cc5 f1608316
find_program(GRAPHVIZ_EXECUTABLE dot)
# Handle REQUIRED and QUIET arguments
# this will also set GRAPHVIZ_FOUND to true if GRAPHVIZ_EXECUTABLE exists
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Graphviz
"Failed to locate graphviz executable"
GRAPHVIZ_EXECUTABLE)
......@@ -15,6 +15,11 @@ set (SRC
tree.cpp
util.cpp
log.cpp
ngraph/descriptor/input.cpp
ngraph/descriptor/output.cpp
ngraph/descriptor/tensor.cpp
ngraph/descriptor/tensor_view.cpp
ops/binary_elementwise_builtin.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/convert.cpp
......@@ -24,6 +29,7 @@ set (SRC
ops/op.cpp
ops/parameter.cpp
ops/tuple.cpp
ops/unary_elementwise_builtin.cpp
types/element_type.cpp
types/type.cpp
ngraph/node.cpp
......@@ -36,15 +42,22 @@ set(NGRAPH_INCLUDE_PATH
${CMAKE_CURRENT_SOURCE_DIR}/ngraph
)
# find_program (GRAPHVIZ dot)
# message (STATUS "graphviz '${GRAPHVIZ}'")
find_package(Graphviz)
if (GRAPHVIZ_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DGRAPHVIZ_FOUND")
endif()
include_directories("${NGRAPH_INCLUDE_PATH}")
add_library(ngraph SHARED ${SRC})
target_include_directories(ngraph PUBLIC "${NGRAPH_INCLUDE_PATH}")
if ( APPLE )
set_property( TARGET ngraph PROPERTY PREFIX "lib" )
set_property( TARGET ngraph PROPERTY OUTPUT_NAME "ngraph.so" )
set_property( TARGET ngraph PROPERTY SUFFIX "" )
if (APPLE)
set_property(TARGET ngraph PROPERTY PREFIX "lib")
set_property(TARGET ngraph PROPERTY OUTPUT_NAME "ngraph.so")
set_property(TARGET ngraph PROPERTY SUFFIX "")
endif()
#-----------------------------------------------------------------------------------------------
......@@ -63,7 +76,7 @@ set(CMAKE_INSTALL_PREFIX "$ENV{HOME}/ngraph_dist" CACHE PATH "Install directory"
message (STATUS "Installation directory: ${CMAKE_INSTALL_PREFIX}")
message (STATUS "To Override use: cmake -DCMAKE_INSTALL_PREFIX=/foo -P cmake_install.cmake")
install(TARGETS ngraph DESTINATION ${CMAKE_INSTALL_PREFIX} )
install(TARGETS ngraph DESTINATION ${CMAKE_INSTALL_PREFIX})
install(DIRECTORY
${CMAKE_CURRENT_SOURCE_DIR}/ngraph
DESTINATION ${CMAKE_INSTALL_PREFIX}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace descriptor;
Input::Input(
Node* node, size_t index, size_t argno, size_t arg_index, const shared_ptr<Output>& output)
: m_node(node)
, m_index(index)
, m_argno(argno)
, m_arg_index(arg_index)
, m_output(output)
{
output->add_input(this);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
namespace ngraph
{
namespace descriptor
{
class Output;
// Describes a tensor that is an input to an op, directly or indirectly via a tuple
class Input : public std::enable_shared_from_this<Input>
{
Input(const Input&) = delete;
Input& operator=(const Input&) = delete;
public:
/// @param node The node that owns this input; not shared to prevent owner loop
/// @param index The position of this this tensor in all input tensors
/// @param argno The position of the argument with this tensor
/// @param arg_index The position of the tensor within the argument's tensors
/// @param output The output that supplies a value for this input
Input(Node* node,
size_t index,
size_t argno,
size_t arg_index,
const std::shared_ptr<Output>& output);
std::shared_ptr<Node> get_node() { return m_node->shared_from_this(); }
size_t get_argno() const { return m_argno; }
size_t get_arg_index() const { return m_arg_index; }
size_t get_index() const { return m_index; }
std::shared_ptr<Output> get_output() const { return m_output; }
protected:
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
std::shared_ptr<Output> m_output;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace descriptor;
Output::Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view)
: m_node(node)
, m_index(index)
, m_tensor_view(tensor_view)
{
}
// Add an input to the vector of inputs that use this output.
void Output::add_input(Input* input)
{
m_inputs.insert(input);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <set>
#include "descriptor/tensor_view.hpp"
namespace ngraph
{
namespace descriptor
{
// Describes an output tensor of an op
class Output : public std::enable_shared_from_this<Output>
{
Output(const Output&) = delete;
Output& operator=(const Output&) = delete;
public:
/// @param node Node that owns this output. Not shared to prevent owner loop.
/// @param index Position of the output tensor in all output tensors
/// @param tensor_view The view of this tensor; where the value will be written
Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view);
std::shared_ptr<Node> get_node() const { return m_node->shared_from_this(); }
size_t get_index() const { return m_index; }
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void add_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; }
protected:
Node* m_node;
size_t m_index;
std::shared_ptr<TensorView> m_tensor_view;
std::set<Input*> m_inputs;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "descriptor/tensor.hpp"
using namespace ngraph;
using namespace descriptor;
Tensor::Tensor(const element::Type& element_type, PrimaryTensorView* primary_tensor_view)
: m_element_type(element_type)
, m_primary_tensor_view(primary_tensor_view)
{
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
namespace ngraph
{
namespace element
{
class Type;
}
namespace descriptor
{
class TensorView;
class PrimaryTensorView;
class Tensor
{
friend class PrimaryTensorView;
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
Tensor(const element::Type& element_type, PrimaryTensorView* tensor_view);
protected:
const element::Type& m_element_type;
PrimaryTensorView* m_primary_tensor_view;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "descriptor/tensor_view.hpp"
using namespace ngraph;
using namespace descriptor;
const Tensor& PrimaryTensorView::get_tensor() const
{
return m_tensor;
}
Tensor& PrimaryTensorView::get_tensor()
{
return m_tensor;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "descriptor/tensor.hpp"
#include "shape.hpp"
#include "type.hpp"
namespace ngraph
{
namespace descriptor
{
class Tensor;
class TensorViewLayout;
// Describes a view of an instantiated tensor
class TensorView : public std::enable_shared_from_this<TensorView>
{
TensorView(const TensorView&) = delete;
TensorView& operator=(const TensorView&) = delete;
protected:
TensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type)
: m_tensor_view_type(tensor_view_type)
{
}
public:
virtual ~TensorView() {}
virtual const Tensor& get_tensor() const = 0;
virtual Tensor& get_tensor() = 0;
std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{
return m_tensor_view_type;
}
const std::shared_ptr<TensorViewLayout>& get_tensor_view_layout() const
{
return m_tensor_view_layout;
}
void set_tensor_view_layout(const std::shared_ptr<TensorViewLayout>& tensor_view_layout)
{
m_tensor_view_layout = tensor_view_layout;
}
protected:
std::shared_ptr<const TensorViewType> m_tensor_view_type;
std::shared_ptr<TensorViewLayout> m_tensor_view_layout;
};
// A PrimaryTensorView owns the tensor. All other views are the result
// of some index operation on the primary view.
class PrimaryTensorView : public TensorView
{
public:
PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type)
: TensorView(tensor_view_type)
, m_tensor(tensor_view_type->get_element_type(), this)
{
}
virtual const Tensor& get_tensor() const override;
virtual Tensor& get_tensor() override;
protected:
Tensor m_tensor;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
namespace ngraph
{
namespace descriptor
{
using Strides = std::vector<size_t>;
class TensorViewLayout
{
protected:
Strides m_strides;
};
}
}
......@@ -23,13 +23,24 @@
#include "except.hpp"
#include "function.hpp"
#include "node.hpp"
#include "descriptor/input.hpp"
#include "descriptor/output.hpp"
#include "descriptor/tensor_view.hpp"
#include "descriptor/tensor_view_layout.hpp"
#include "descriptor/tensor.hpp"
#include "op.hpp"
#include "ops/add.hpp"
#include "ops/broadcast.hpp"
#include "ops/ceiling.hpp"
#include "ops/concatenate.hpp"
#include "ops/constant.hpp"
#include "ops/convert.hpp"
#include "ops/divide.hpp"
#include "ops/dot.hpp"
#include "ops/floor.hpp"
#include "ops/multiply.hpp"
#include "ops/parameter.hpp"
#include "ops/subtract.hpp"
#include "ops/tuple.hpp"
#include "shape.hpp"
#include "type.hpp"
......@@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "node.hpp"
#include "op.hpp"
#include "ngraph.hpp"
size_t ngraph::Node::m_next_instance_id = 0;
using namespace std;
using namespace ngraph;
ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
std::shared_ptr<ValueType> value_type)
size_t Node::m_next_instance_id = 0;
Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<ValueType> value_type)
: m_arguments(arguments)
, m_value_type(value_type)
, m_instance_id(m_next_instance_id++)
......@@ -30,33 +31,63 @@ ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
}
}
void ngraph::Node::set_value_type_checked(const std::shared_ptr<ValueType>& value_type)
void Node::set_value_type_checked(const shared_ptr<ValueType>& value_type)
{
if (nullptr == m_value_type){
if (nullptr == m_value_type)
{
m_value_type = value_type;
} else {
if (*m_value_type != *value_type){
throw ngraph::ngraph_error("Setting value type to a different ValueType");
}
else
{
if (*m_value_type != *value_type)
{
throw ngraph_error("Setting value type to a different ValueType");
}
}
}
void Node::assign_tensors()
{
vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
get_value_type()->collect_tensor_views(tensor_view_types);
size_t i = 0;
for (auto tvt : tensor_view_types)
{
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt);
auto output = make_shared<descriptor::Output>(this, i++, tensor_view_descriptor);
m_outputs.push_back(output);
}
i = 0;
size_t argno = 0;
for (auto arg : get_arguments())
{
size_t arg_index = 0;
for (auto output : arg->get_outputs())
{
auto input = make_shared<descriptor::Input>(this, i++, argno, arg_index++, output);
m_inputs.push_back(input);
}
argno++;
}
}
bool ngraph::Node::is_op() const
bool Node::is_op() const
{
return dynamic_cast<const ngraph::Op*>(this) != nullptr;
return dynamic_cast<const Op*>(this) != nullptr;
}
bool ngraph::Node::is_parameter() const
bool Node::is_parameter() const
{
return dynamic_cast<const ngraph::op::Parameter*>(this) != nullptr;
return dynamic_cast<const op::Parameter*>(this) != nullptr;
}
namespace ngraph
{
std::ostream& operator<<(std::ostream& out, const ngraph::Node& node)
ostream& operator<<(ostream& out, const Node& node)
{
auto op_tmp = dynamic_cast<const ngraph::Op*>(&node);
auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node);
auto op_tmp = dynamic_cast<const Op*>(&node);
auto parameter_tmp = dynamic_cast<const Op*>(&node);
if (op_tmp)
{
out << "Op(" << op_tmp->get_node_id() << ")";
......
......@@ -27,6 +27,12 @@ namespace ngraph
{
class Op;
namespace descriptor
{
class Input;
class Output;
}
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values.
......@@ -53,6 +59,10 @@ namespace ngraph
/// Propagate types and check arguments for consistency
virtual void propagate_types() = 0;
/// Assign Input and Output vectors
// This might later need to be virtual.
void assign_tensors();
const Nodes& get_arguments() const { return m_arguments; }
const std::multiset<Node*>& users() const { return m_users; }
......@@ -95,6 +105,8 @@ namespace ngraph
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
std::vector<std::shared_ptr<descriptor::Input>> get_inputs() { return m_inputs; }
std::vector<std::shared_ptr<descriptor::Output>> get_outputs() {return m_outputs; }
protected:
Nodes m_arguments;
......@@ -103,5 +115,7 @@ namespace ngraph
std::string m_name;
size_t m_instance_id;
static size_t m_next_instance_id;
std::vector<std::shared_ptr<descriptor::Input>> m_inputs;
std::vector<std::shared_ptr<descriptor::Output>> m_outputs;
};
}
......@@ -61,10 +61,6 @@ namespace ngraph
{
public:
virtual std::string description() const override { return "Builtin"; }
/// Name of the builtin op, for debugging and logging.
// TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {}
protected:
Builtin(const std::vector<std::shared_ptr<Node>>& args)
......@@ -73,58 +69,60 @@ namespace ngraph
}
};
class Abs : public Builtin
/// Index ops create a new way to index the same tensor elements
class IndexBuiltin : public Builtin
{
public:
Abs(const std::shared_ptr<Node>& arg0)
: Builtin({arg0})
protected:
IndexBuiltin(const std::shared_ptr<Node>& arg)
: Builtin(Nodes{arg})
{
}
virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override;
};
class Add : public Builtin
/// Operations where the same element function is applied to each element
/// Op(X)[I] = op(X[I])
class UnaryElementwiseBuiltin : public Builtin
{
public:
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
protected:
UnaryElementwiseBuiltin(const std::shared_ptr<Node>& arg)
: Builtin(Nodes{arg})
{
}
virtual std::string get_op_class_name() const override { return "Add"; }
//virtual void propagate_types() override;
public:
virtual void propagate_types() override;
};
class Ceiling : public Builtin
/// Op(X, Y)[I] = op(X[I], Y[I])
class BinaryElementwiseBuiltin : public Builtin
{
public:
Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
protected:
BinaryElementwiseBuiltin(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin(Nodes{arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Ceiling"; }
//virtual void propagate_types() override;
public:
virtual void propagate_types() override;
};
class Divide : public Builtin
class Abs : public UnaryElementwiseBuiltin
{
public:
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
Abs(const std::shared_ptr<Node>& arg0)
: UnaryElementwiseBuiltin({arg0})
{
}
virtual std::string get_op_class_name() const override { return "Divide"; }
virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override;
};
class Equal : public Builtin
class Equal : public BinaryElementwiseBuiltin
{
public:
Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -132,11 +130,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Exp : public Builtin
class Exp : public UnaryElementwiseBuiltin
{
public:
Exp(const std::shared_ptr<Node>& arg0)
: Builtin({arg0})
: UnaryElementwiseBuiltin(arg0)
{
}
......@@ -144,23 +142,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Floor : public Builtin
{
public:
Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Floor"; }
//virtual void propagate_types() override;
};
class Greater : public Builtin
class Greater : public BinaryElementwiseBuiltin
{
public:
Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -168,11 +154,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Less : public Builtin
class Less : public BinaryElementwiseBuiltin
{
public:
Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -180,11 +166,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Log : public Builtin
class Log : public UnaryElementwiseBuiltin
{
public:
Log(const std::shared_ptr<Node>& arg0)
: Builtin({arg0})
: UnaryElementwiseBuiltin(arg0)
{
}
......@@ -192,11 +178,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Maximum : public Builtin
class Maximum : public BinaryElementwiseBuiltin
{
public:
Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -204,11 +190,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Minimum : public Builtin
class Minimum : public BinaryElementwiseBuiltin
{
public:
Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -216,23 +202,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Multiply : public Builtin
{
public:
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Multiply"; }
//virtual void propagate_types() override;
};
class Negative : public Builtin
class Negative : public UnaryElementwiseBuiltin
{
public:
Negative(const std::shared_ptr<Node>& arg0)
: Builtin({arg0})
: UnaryElementwiseBuiltin(arg0)
{
}
......@@ -240,11 +214,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Power : public Builtin
class Power : public BinaryElementwiseBuiltin
{
public:
Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -252,11 +226,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Remainder : public Builtin
class Remainder : public BinaryElementwiseBuiltin
{
public:
Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -264,11 +238,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Reshape : public Builtin
class Reshape : public IndexBuiltin
{
public:
Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
: Builtin({arg0})
: IndexBuiltin(arg0)
, m_shape(shape)
{
}
......@@ -278,17 +252,5 @@ namespace ngraph
protected:
Shape m_shape;
};
class Subtract : public Builtin
{
public:
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Subtract"; }
//virtual void propagate_types() override;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Add : public BinaryElementwiseBuiltin
{
public:
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Add"; }
};
}
}
......@@ -18,7 +18,7 @@ namespace ngraph
{
namespace op
{
class Broadcast : public Builtin
class Broadcast : public IndexBuiltin
{
public:
///
......@@ -30,7 +30,7 @@ namespace ngraph
Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: Builtin({arg})
: IndexBuiltin(arg)
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Ceiling : public BinaryElementwiseBuiltin
{
public:
Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Ceiling"; }
};
}
}
......@@ -18,11 +18,11 @@ namespace ngraph
{
namespace op
{
class Convert : public Builtin
class Convert : public UnaryElementwiseBuiltin
{
public:
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: Builtin({arg})
: UnaryElementwiseBuiltin({arg})
, m_element_type(element_type)
{
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Divide : public BinaryElementwiseBuiltin
{
public:
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Divide"; }
};
}
}
......@@ -21,7 +21,25 @@ namespace ngraph
class Dot : public Builtin
{
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
/// Computes the dot product of two tensors.
///
/// There are three possible cases:
/// (1) arg0 or arg1 is 0-dimensional. Then, we treat the 0-dimensional
/// argument(s) as scalars and compute a scalar-tensor or
/// scalar-scalar product.
/// (Example: arg0 has shape {1,2,3} and arg1 has shape {}; then
/// the result will have shape {1,2,3}.)
///
/// (2) arg1 is 1-dimensional. Then, we compute a dot product reducing
/// on the innermost (rightmost) dimensions of arg0 and arg1.
/// (Example: arg0 has shape {1,2,3} and arg1 has shape {3}; then
/// the result will have shape {1,2}.)
///
/// (3) arg1 is more than 1-dimensional. Then, we compute a dot product
/// reducing on the innermost (rightmost) dimension of arg0, and the
/// next-to-innermost dimension of arg1.
/// (Example: arg0 has shape {3,4} and arg1 has shape {4,3}; then
/// the result will have shape {3,3}.)
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Floor : public BinaryElementwiseBuiltin
{
public:
Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Floor"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Multiply : public BinaryElementwiseBuiltin
{
public:
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Multiply"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Subtract : public BinaryElementwiseBuiltin
{
public:
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Subtract"; }
};
}
}
......@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <unordered_map>
#include <deque>
#include "topological_sort.hpp"
#include "node.hpp"
#include "util.hpp"
......@@ -19,31 +22,19 @@
using namespace ngraph;
using namespace std;
void ngraph::TopologicalSort::promote_node(Node* n)
{
for (auto dn = m_dependent_nodes.begin(); dn != m_dependent_nodes.end(); dn++)
{
if (dn->first > 0) // Skip zero as they should never be promoted
{
auto it = find(dn->second.begin(), dn->second.end(), n);
if (it != dn->second.end())
{
// found the node
dn->second.erase(it);
m_dependent_nodes[dn->first - 1].push_back(n);
}
}
}
}
void ngraph::TopologicalSort::process(node_ptr p)
{
deque<Node*> independent_nodes;
unordered_map<Node*, size_t> node_depencency_count;
traverse_nodes(p, [&](node_ptr node) {
list<Node*>& node_list = m_dependent_nodes[node->get_arguments().size()];
node_list.push_back(node.get());
node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node.get());
}
});
list<Node*>& independent_nodes = m_dependent_nodes[0];
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
......@@ -52,12 +43,22 @@ void ngraph::TopologicalSort::process(node_ptr p)
for (auto user : independent_node->users())
{
promote_node(user);
node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
}
}
const std::vector<Node*>& ngraph::TopologicalSort::get_sorted_list() const
const std::list<Node*>& ngraph::TopologicalSort::get_sorted_list() const
{
return m_sorted_list;
}
std::list<Node*>& ngraph::TopologicalSort::get_sorted_list()
{
return m_sorted_list;
}
......@@ -14,10 +14,8 @@
#pragma once
#include <list>
#include <map>
#include <memory>
#include <vector>
#include <list>
namespace ngraph
{
......@@ -32,11 +30,11 @@ public:
TopologicalSort() {}
void process(node_ptr);
const std::vector<Node*>& get_sorted_list() const;
const std::list<Node*>& get_sorted_list() const;
std::list<Node*>& get_sorted_list();
private:
void promote_node(Node* n);
std::map<size_t, std::list<Node*>> m_dependent_nodes;
std::vector<Node*> m_sorted_list;
std::list<Node*> m_sorted_list;
};
......@@ -34,10 +34,13 @@ namespace ngraph
virtual ~ValueType() {}
virtual bool operator==(const ValueType& that) const = 0;
bool operator!=(const ValueType& that) const { return !(*this == that); }
/// Add tensor views in depth-first order.
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0;
};
/// Describes a tensor view; an element type and a shape.
class TensorViewType : public ValueType
class TensorViewType : public ValueType, public std::enable_shared_from_this<TensorViewType>
{
public:
/// /param element_type The type of the tensor elements.
......@@ -52,6 +55,7 @@ namespace ngraph
const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
protected:
const element::Type& m_element_type;
......@@ -78,6 +82,7 @@ namespace ngraph
std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; }
virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
protected:
std::vector<std::shared_ptr<ValueType>> m_element_types;
......
......@@ -34,13 +34,46 @@ void Visualize::add(node_ptr p)
traverse_nodes(p, [&](node_ptr node) {
for (auto arg : node->get_arguments())
{
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n";
m_ss << add_attributes(arg);
m_ss << add_attributes(node);
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id();
m_ss << ";\n";
}
});
}
std::string Visualize::add_attributes(node_ptr node)
{
string rc;
if (!contains(m_nodes_with_attributes, node))
{
m_nodes_with_attributes.insert(node);
rc = get_attributes(node);
}
return rc;
}
std::string Visualize::get_attributes(node_ptr node)
{
stringstream ss;
if (node->is_parameter())
{
ss << " " << node->get_node_id() << " [shape=box color=blue]\n";
}
else if (node->is_op())
{
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
}
else
{
ss << " " << node->get_node_id() << " [shape=diamond color=red]\n";
}
return ss.str();
}
void Visualize::save_dot(const string& path) const
{
#if GRAPHVIZ_FOUND
auto tmp_file = path + ".tmp";
ofstream out(tmp_file);
if (out)
......@@ -56,6 +89,8 @@ void Visualize::save_dot(const string& path) const
auto stream = popen(cmd.c_str(), "r");
pclose(stream);
// remove(tmp_file.c_str());
remove(tmp_file.c_str());
}
#else
#endif
}
......@@ -36,6 +36,10 @@ public:
void save_dot(const std::string& path) const;
private:
std::string add_attributes(node_ptr node);
std::string get_attributes(node_ptr node);
std::stringstream m_ss;
std::string m_name;
std::set<node_ptr> m_nodes_with_attributes;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph::op;
void BinaryElementwiseBuiltin::propagate_types()
{
if (m_arguments.size() != 2)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments must be tensor views");
}
if (*arg0_tensor_type != *arg1_tensor_type)
{
throw ngraph_error("Arguments must have the same tensor view type");
}
set_value_type_checked(arg0_tensor_type);
}
......@@ -34,12 +34,12 @@ void Dot::propagate_types()
throw ngraph_error("Arguments to dot must have the same element type");
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
const bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0;
if (arg1_shape.size() > 1)
{
arg1_reduction = arg1_shape.size() - 2;
......@@ -48,13 +48,30 @@ void Dot::propagate_types()
{
arg1_reduction = arg1_shape.size() - 1;
}
if (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction))
if (!is_scalar_mult && (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction)))
{
throw ngraph_error("Dot reduction axes not compatible");
}
vector<size_t> result_shape;
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
set_value_type_checked(make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape));
result_shape.reserve(arg0_shape.size() + arg1_shape.size() - (is_scalar_mult ? 0 : 2));
for(auto i = 0; i < arg0_shape.size(); i++)
{
if(is_scalar_mult || i != arg0_reduction)
{
result_shape.push_back(arg0_shape[i]);
}
}
for(auto i = 0; i < arg1_shape.size(); i++)
{
if(is_scalar_mult || i != arg1_reduction)
{
result_shape.push_back(arg1_shape[i]);
}
}
auto result_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
set_value_type_checked(result_type);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph::op;
void UnaryElementwiseBuiltin::propagate_types()
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
if (nullptr == arg_tensor_type)
{
throw ngraph_error("Argument must be tensor view");
}
set_value_type_checked(arg_tensor_type);
}
......@@ -37,6 +37,11 @@ bool TensorViewType::operator==(const ValueType& that) const
return true;
}
void TensorViewType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const
{
views.push_back(shared_from_this());
}
bool TupleType::operator==(const ValueType& that) const
{
auto that_tvt = dynamic_cast<const TupleType*>(&that);
......@@ -46,3 +51,10 @@ bool TupleType::operator==(const ValueType& that) const
}
return that_tvt->get_element_types() == get_element_types();
}
void TupleType::collect_tensor_views(std::vector<std::shared_ptr<const TensorViewType>>& views) const
{
for(auto elt : m_element_types){
elt->collect_tensor_views(views);
}
}
......@@ -29,15 +29,16 @@ link_directories(
set (SRC
main.cpp
build_graph.cpp
util.cpp
tensor.cpp
eigen.cpp
element_type.cpp
uuid.cpp
mkldnn.cpp
op.cpp
input_output_assign.cpp
tensor.cpp
topological_sort.cpp
type_prop.cpp
op.cpp
eigen.cpp
mkldnn.cpp
util.cpp
uuid.cpp
)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
TEST(input_output, param_tensor)
{
// Params have no arguments, so we can check that the value becomes a tensor output
auto tv_tp = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto param = make_shared<op::Parameter>(tv_tp);
param->propagate_types();
param->assign_tensors();
ASSERT_EQ(param->get_outputs().size(), 1);
for (size_t i = 0; i < param->get_outputs().size(); i++)
{
auto output = param->get_outputs()[i];
ASSERT_EQ(i, output->get_index());
ASSERT_EQ(param, output->get_node());
}
ASSERT_EQ(*tv_tp, *param->get_outputs()[0]->get_tensor_view()->get_tensor_view_type());
}
TEST(input_output, param_tuple)
{
// Same as param_tensor, but for a tuple
auto tv_tp_0 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto tv_tp_1 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4, 6});
auto tp_tp = make_shared<TupleType>(std::vector<std::shared_ptr<ValueType>>{tv_tp_0, tv_tp_1});
auto param = make_shared<op::Parameter>(tp_tp);
param->propagate_types();
param->assign_tensors();
ASSERT_EQ(param->get_outputs().size(), 2);
for (size_t i = 0; i < param->get_outputs().size(); i++)
{
auto output = param->get_outputs()[i];
ASSERT_EQ(i, output->get_index());
ASSERT_EQ(param, output->get_node());
}
ASSERT_EQ(*tv_tp_0, *param->get_outputs()[0]->get_tensor_view()->get_tensor_view_type());
ASSERT_EQ(*tv_tp_1, *param->get_outputs()[1]->get_tensor_view()->get_tensor_view_type());
}
TEST(input_output, simple_output)
{
auto tv_tp_0 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto param_0 = make_shared<op::Parameter>(tv_tp_0);
auto param_1 = make_shared<op::Parameter>(tv_tp_0);
auto add = make_shared<op::Add>(param_0, param_1);
// Sort the ops
vector<shared_ptr<Node>> nodes;
nodes.push_back(param_0);
nodes.push_back(param_1);
nodes.push_back(add);
// Type info
for (auto node : nodes)
{
node->propagate_types();
}
// Add inputs/outputs
for (auto node : nodes)
{
node->assign_tensors();
}
// At this point, the add should have each input associated with the output of the appropriate parameter
auto inputs = add->get_inputs();
ASSERT_EQ(2, inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
auto input = inputs[i];
ASSERT_EQ(i, input->get_index());
ASSERT_EQ(i, input->get_argno());
ASSERT_EQ(0, input->get_arg_index());
ASSERT_EQ(input->get_output()->get_node(), add->get_arguments()[i]);
}
}
......@@ -22,11 +22,12 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/topological_sort.hpp"
#include "ngraph/visualize.hpp"
#include "util.hpp"
using namespace std;
using namespace ngraph;
static bool validate_list(const vector<Node*>& nodes)
static bool validate_list(const list<Node*>& nodes)
{
bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
......@@ -38,7 +39,7 @@ static bool validate_list(const vector<Node*>& nodes)
{
dependencies.push_back(n.get());
}
auto tmp = it + 1;
auto tmp = it++;
for (; tmp != nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
......@@ -87,12 +88,18 @@ TEST(topological_sort, basic)
ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0);
Visualize vz;
vz.add(r0);
vz.save_dot("test.png");
// Visualize vz;
// vz.add(r0);
// vz.save_dot("test.png");
TopologicalSort ts;
ts.process(r0);
auto sorted_list = ts.get_sorted_list();
size_t node_count = 0;
traverse_nodes(r0, [&](node_ptr node) {
node_count++;
});
EXPECT_EQ(node_count, sorted_list.size());
EXPECT_TRUE(validate_list(sorted_list));
}
......@@ -20,6 +20,14 @@
using namespace std;
using namespace ngraph;
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node);
void test_binary_bad_arguments_views(const shared_ptr<Node>& node);
void test_binary_good_arguments(const shared_ptr<Node>& node);
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y));
//
// Tests for broadcast.
//
TEST(type_prop, broadcast_deduce)
{
// Deduce type
......@@ -52,6 +60,7 @@ TEST(type_prop, broadcast_deduce_incorrect)
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
catch (const ngraph_error& error)
......@@ -72,6 +81,7 @@ TEST(type_prop, broadcast_bad_arguments)
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Tuple argument to broadcast not detected.";
}
catch (const ngraph_error& error)
......@@ -84,3 +94,245 @@ TEST(type_prop, broadcast_bad_arguments)
}
}
//
// Tests for dot product.
//
TEST(type_prop, dot_deduce_scalar_2d)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,5});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,5}));
}
TEST(type_prop, dot_deduce_2d_scalar)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,5});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,5}));
}
TEST(type_prop, dot_deduce_scalar_scalar)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{}));
}
TEST(type_prop, dot_deduce_scalar_1d)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{6});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{6}));
}
TEST(type_prop, dot_deduce_1d)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{}));
}
TEST(type_prop, dot_deduce_2d)
{
// Deduce type for 2D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2,3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,3}));
}
TEST(type_prop, dot_deduce_different_d)
{
// Deduce type for different-dimension arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2,8,4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1,2,3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2,8,4,1,3}));
}
TEST(type_prop, dot_deduce_different_d_correct)
{
// Deduced type matches explicitly set type
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2,8,4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1,2,3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->set_value_type(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2,8,4,1,3}));
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2,8,4,1,3}));
}
TEST(type_prop, dot_deduce_element_type_mismatch)
{
// Type deduction fails due to element type mismatch
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,2});
auto param2 = make_shared<op::Parameter>(element::Int32::element_type(), Shape{2,5});
auto bc = make_shared<op::Dot>(param1, param2);
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Element type mismatch not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments to dot must have the same element type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dot_deduce_reduction_axes_size_mismatch)
{
// Type deduction fails due to reduction axes size mismatch
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3,5});
auto bc = make_shared<op::Dot>(param1, param2);
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Dot reduction axes size mismatch not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Dot reduction axes not compatible"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
//
// Tests for binary elementwise ops.
//
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node)
{
try
{
node->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Tuple argument not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must be tensor views"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_bad_arguments_views(const shared_ptr<Node>& node)
{
try
{
node->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_good_arguments(const shared_ptr<Node>& node)
{
node->propagate_types();
EXPECT_EQ(*node->get_value_type(), *node->get_arguments()[0]->get_value_type());
}
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
{
// Check for bad arguments
auto tp0_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tp1_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_4_2_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 2}));
test_binary_bad_arguments_tuple(f(tp0_param, tp1_param));
test_binary_bad_arguments_tuple(f(tp0_param, tv0_2_4_param_0));
test_binary_bad_arguments_tuple(f(tv0_2_4_param_0, tp0_param));
test_binary_bad_arguments_views(f(tv0_2_4_param_0, tv0_4_2_param));
test_binary_good_arguments(f(tv0_2_4_param_0, tv0_2_4_param_1));
}
TEST(type_prop, add_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Add>(x, y);
});
}
TEST(type_prop, ceiling_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Ceiling>(x, y);
});
}
TEST(type_prop, divide_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Divide>(x, y);
});
}
TEST(type_prop, floor_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Floor>(x, y);
});
}
TEST(type_prop, multiply_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Multiply>(x, y);
});
}
TEST(type_prop, subtract_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Subtract>(x, y);
});
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment