Commit 34d84d3d authored by pthoreho's avatar pthoreho

Merge remote-tracking branch 'origin' into pruthvi/mkldnn_elementwise_add

parents 9a376ca6 7f08b97b
...@@ -30,20 +30,24 @@ if (NGRAPH_ARGON_ENABLE) ...@@ -30,20 +30,24 @@ if (NGRAPH_ARGON_ENABLE)
endif() endif()
# Repository # Repository
set(ARGON_TRANSFORMER_CMAKE_GIT_REPOSITORY git@github.com:NervanaSystems/argon-transformer.git) if (DEFINED CUSTOM_ARGON_TRANSFORMER_GIT_REPOSITORY)
set(ARGON_TRANSFORMER_GIT_REPOSITORY ${CUSTOM_ARGON_TRANSFORMER_GIT_REPOSITORY})
else()
set(ARGON_TRANSFORMER_GIT_REPOSITORY git@github.com:NervanaSystems/argon-transformer.git)
endif()
# Set argon_transformer tag # Set argon_transformer tag
# Notes: # Notes:
# - Before we have ngraph CI job for argon transformer, ngraph master might not be # - Before we have ngraph CI job for argon transformer, ngraph master might not be
# compatible with argon transformer. To ensure compatibility, checkout the ngraph commit point # compatible with argon transformer. To ensure compatibility, checkout the ngraph commit point
# where the following `ARGON_TRANSFORMER_CMAKE_GIT_TAG` is set and build ngraph with argon using this # where the following `ARGON_TRANSFORMER_GIT_TAG` is set and build ngraph with argon using this
# commit. # commit.
# - After we have ngraph CI job for argon transformer, ngraph master will be compatible with # - After we have ngraph CI job for argon transformer, ngraph master will be compatible with
# argon transformer guaranteed by CI. # argon transformer guaranteed by CI.
set(ARGON_TRANSFORMER_CMAKE_GIT_TAG cpp-master) set(ARGON_TRANSFORMER_GIT_TAG cpp-master)
# Determines where argon-transformer will be located # Determines where argon-transformer will be located
set(ARGON_TRANSFORMER_CMAKE_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/argon_transformer) set(ARGON_TRANSFORMER_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/argon_transformer)
# Print # Print
message(STATUS "NGRAPH_INCLUDE_PATH: ${NGRAPH_INCLUDE_PATH}") message(STATUS "NGRAPH_INCLUDE_PATH: ${NGRAPH_INCLUDE_PATH}")
...@@ -56,27 +60,35 @@ if (NGRAPH_ARGON_ENABLE) ...@@ -56,27 +60,35 @@ if (NGRAPH_ARGON_ENABLE)
ExternalProject_Add( ExternalProject_Add(
ext_argon_transformer ext_argon_transformer
SOURCE_DIR ${CUSTOM_ARGON_TRANSFORMER_DIR} SOURCE_DIR ${CUSTOM_ARGON_TRANSFORMER_DIR}
PREFIX ${ARGON_TRANSFORMER_CMAKE_PREFIX} PREFIX ${ARGON_TRANSFORMER_PREFIX}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} CMAKE_ARGS
-DNGRAPH_INSTALL_PREFIX=${ARGON_TRANSFORMER_CMAKE_PREFIX} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DNGRAPH_INSTALL_PREFIX=${ARGON_TRANSFORMER_PREFIX}
-DPREBUILD_ARGON_API_PATH=${NGRAPH_PREBUILD_ARGON_API_PATH} -DPREBUILD_ARGON_API_PATH=${NGRAPH_PREBUILD_ARGON_API_PATH}
-DEXTERNAL_NGRAPH_INCLUDE_DIR=${NGRAPH_INCLUDE_PATH} -DEXTERNAL_NGRAPH_INCLUDE_DIR=${NGRAPH_INCLUDE_PATH}
-DINSTALLED_HEADERS_PATH=${CMAKE_INSTALL_PREFIX}/include -DINSTALLED_HEADERS_PATH=${CMAKE_INSTALL_PREFIX}/include
-DMKLDNN_INCLUDE_DIR=${MKLDNN_INCLUDE_DIR}
BUILD_ALWAYS 1 BUILD_ALWAYS 1
) )
else() else()
ExternalProject_Add( ExternalProject_Add(
ext_argon_transformer ext_argon_transformer
GIT_REPOSITORY ${ARGON_TRANSFORMER_CMAKE_GIT_REPOSITORY} GIT_REPOSITORY ${ARGON_TRANSFORMER_GIT_REPOSITORY}
GIT_TAG ${ARGON_TRANSFORMER_CMAKE_GIT_TAG} GIT_TAG ${ARGON_TRANSFORMER_GIT_TAG}
PREFIX ${ARGON_TRANSFORMER_CMAKE_PREFIX} PREFIX ${ARGON_TRANSFORMER_PREFIX}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} CMAKE_ARGS
-DNGRAPH_INSTALL_PREFIX=${ARGON_TRANSFORMER_CMAKE_PREFIX} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DNGRAPH_INSTALL_PREFIX=${ARGON_TRANSFORMER_PREFIX}
-DPREBUILD_ARGON_API_PATH=${NGRAPH_PREBUILD_ARGON_API_PATH} -DPREBUILD_ARGON_API_PATH=${NGRAPH_PREBUILD_ARGON_API_PATH}
-DEXTERNAL_NGRAPH_INCLUDE_DIR=${NGRAPH_INCLUDE_PATH} -DEXTERNAL_NGRAPH_INCLUDE_DIR=${NGRAPH_INCLUDE_PATH}
-DINSTALLED_HEADERS_PATH=${CMAKE_INSTALL_PREFIX}/include -DINSTALLED_HEADERS_PATH=${CMAKE_INSTALL_PREFIX}/include
-DMKLDNN_INCLUDE_DIR=${MKLDNN_INCLUDE_DIR}
BUILD_ALWAYS 1 BUILD_ALWAYS 1
) )
endif() endif()
...@@ -85,29 +97,37 @@ if (NGRAPH_ARGON_ENABLE) ...@@ -85,29 +97,37 @@ if (NGRAPH_ARGON_ENABLE)
ExternalProject_Add( ExternalProject_Add(
ext_argon_transformer ext_argon_transformer
SOURCE_DIR ${CUSTOM_ARGON_TRANSFORMER_DIR} SOURCE_DIR ${CUSTOM_ARGON_TRANSFORMER_DIR}
PREFIX ${ARGON_TRANSFORMER_CMAKE_PREFIX} PREFIX ${ARGON_TRANSFORMER_PREFIX}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} CMAKE_ARGS
-DNGRAPH_INSTALL_PREFIX=${ARGON_TRANSFORMER_CMAKE_PREFIX} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DNGRAPH_INSTALL_PREFIX=${ARGON_TRANSFORMER_PREFIX}
-DPREBUILD_ARGON_API_PATH=${NGRAPH_PREBUILD_ARGON_API_PATH} -DPREBUILD_ARGON_API_PATH=${NGRAPH_PREBUILD_ARGON_API_PATH}
-DEXTERNAL_NGRAPH_INCLUDE_DIR=${NGRAPH_INCLUDE_PATH} -DEXTERNAL_NGRAPH_INCLUDE_DIR=${NGRAPH_INCLUDE_PATH}
-DINSTALLED_HEADERS_PATH=${CMAKE_INSTALL_PREFIX}/include -DINSTALLED_HEADERS_PATH=${CMAKE_INSTALL_PREFIX}/include
BUILD_BYPRODUCTS ${ARGON_TRANSFORMER_CMAKE_PREFIX} -DMKLDNN_INCLUDE_DIR=${MKLDNN_INCLUDE_DIR}
BUILD_BYPRODUCTS ${ARGON_TRANSFORMER_PREFIX}
BUILD_ALWAYS 1 BUILD_ALWAYS 1
) )
else() else()
ExternalProject_Add( ExternalProject_Add(
ext_argon_transformer ext_argon_transformer
GIT_REPOSITORY ${ARGON_TRANSFORMER_CMAKE_GIT_REPOSITORY} GIT_REPOSITORY ${ARGON_TRANSFORMER_GIT_REPOSITORY}
GIT_TAG ${ARGON_TRANSFORMER_CMAKE_GIT_TAG} GIT_TAG ${ARGON_TRANSFORMER_GIT_TAG}
PREFIX ${ARGON_TRANSFORMER_CMAKE_PREFIX} PREFIX ${ARGON_TRANSFORMER_PREFIX}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} CMAKE_ARGS
-DNGRAPH_INSTALL_PREFIX=${ARGON_TRANSFORMER_CMAKE_PREFIX} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DNGRAPH_INSTALL_PREFIX=${ARGON_TRANSFORMER_PREFIX}
-DPREBUILD_ARGON_API_PATH=${NGRAPH_PREBUILD_ARGON_API_PATH} -DPREBUILD_ARGON_API_PATH=${NGRAPH_PREBUILD_ARGON_API_PATH}
-DEXTERNAL_NGRAPH_INCLUDE_DIR=${NGRAPH_INCLUDE_PATH} -DEXTERNAL_NGRAPH_INCLUDE_DIR=${NGRAPH_INCLUDE_PATH}
-DINSTALLED_HEADERS_PATH=${CMAKE_INSTALL_PREFIX}/include -DINSTALLED_HEADERS_PATH=${CMAKE_INSTALL_PREFIX}/include
BUILD_BYPRODUCTS ${ARGON_TRANSFORMER_CMAKE_PREFIX} -DMKLDNN_INCLUDE_DIR=${MKLDNN_INCLUDE_DIR}
BUILD_BYPRODUCTS ${ARGON_TRANSFORMER_PREFIX}
BUILD_ALWAYS 1 BUILD_ALWAYS 1
) )
endif() endif()
...@@ -115,8 +135,8 @@ if (NGRAPH_ARGON_ENABLE) ...@@ -115,8 +135,8 @@ if (NGRAPH_ARGON_ENABLE)
ExternalProject_Get_Property(ext_argon_transformer source_dir) ExternalProject_Get_Property(ext_argon_transformer source_dir)
set(ARGON_TRANSFORMER_SOURCE_DIR ${source_dir} PARENT_SCOPE) set(ARGON_TRANSFORMER_SOURCE_DIR ${source_dir} PARENT_SCOPE)
set(ARGON_TRANSFORMER_INCLUDE_DIR ${ARGON_TRANSFORMER_CMAKE_PREFIX}/include PARENT_SCOPE) set(ARGON_TRANSFORMER_INCLUDE_DIR ${ARGON_TRANSFORMER_PREFIX}/include PARENT_SCOPE)
set(ARGON_TRANSFORMER_LIB_DIR ${ARGON_TRANSFORMER_CMAKE_PREFIX}/lib PARENT_SCOPE) set(ARGON_TRANSFORMER_LIB_DIR ${ARGON_TRANSFORMER_PREFIX}/lib PARENT_SCOPE)
set(ARGON_API_INCLUDE_DIR ${NGRAPH_PREBUILD_ARGON_API_PATH}/include PARENT_SCOPE) set(ARGON_API_INCLUDE_DIR ${NGRAPH_PREBUILD_ARGON_API_PATH}/include PARENT_SCOPE)
set(ARGON_API_LIB_DIR ${NGRAPH_PREBUILD_ARGON_API_PATH}/lib) # Used by find_library below set(ARGON_API_LIB_DIR ${NGRAPH_PREBUILD_ARGON_API_PATH}/lib) # Used by find_library below
set(ARGON_API_LIB_DIR ${NGRAPH_PREBUILD_ARGON_API_PATH}/lib PARENT_SCOPE) set(ARGON_API_LIB_DIR ${NGRAPH_PREBUILD_ARGON_API_PATH}/lib PARENT_SCOPE)
......
...@@ -81,4 +81,8 @@ if(NGRAPH_CPU_ENABLE) ...@@ -81,4 +81,8 @@ if(NGRAPH_CPU_ENABLE)
set(MKLDNN_INCLUDE_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/include" PARENT_SCOPE) set(MKLDNN_INCLUDE_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/include" PARENT_SCOPE)
set(MKLDNN_LIB_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/lib" PARENT_SCOPE) set(MKLDNN_LIB_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/lib" PARENT_SCOPE)
# Other .cmake files in current scope (e.g. Argon Transformer) needs this path as well
set(MKLDNN_INCLUDE_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/include")
set(MKLDNN_LIB_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/lib")
endif() endif()
...@@ -42,12 +42,14 @@ bash_lib_status "Verified that '${CLANG_FORMAT_PROG}' has version '${REQUIRED_CL ...@@ -42,12 +42,14 @@ bash_lib_status "Verified that '${CLANG_FORMAT_PROG}' has version '${REQUIRED_CL
pushd "${THIS_SCRIPT_DIR}/.." pushd "${THIS_SCRIPT_DIR}/.."
declare ARGON_SRC_DIR="build/third-party/argon_transformer/src/ext_argon_transformer/src"
declare ARGON_TEST_DIR="build/third-party/argon_transformer/src/ext_argon_transformer/test"
declare ROOT_SUBDIR declare ROOT_SUBDIR
for ROOT_SUBDIR in src test; do for ROOT_SUBDIR in src test ${ARGON_SRC_DIR} ${ARGON_TEST_DIR}; do
if ! [[ -d "${ROOT_SUBDIR}" ]]; then if ! [[ -d "${ROOT_SUBDIR}" ]]; then
bash_lib_die "In directory '$(pwd)', no subdirectory named '${ROOT_SUBDIR}' was found." bash_lib_status "In directory '$(pwd)', no subdirectory named '${ROOT_SUBDIR}' was found."
fi else
bash_lib_status "About to format C/C++ code in directory tree '$(pwd)/${ROOT_SUBDIR}' ..." bash_lib_status "About to format C/C++ code in directory tree '$(pwd)/${ROOT_SUBDIR}' ..."
# Note that we restrict to "-type f" to exclude symlinks. Emacs sometimes # Note that we restrict to "-type f" to exclude symlinks. Emacs sometimes
...@@ -56,7 +58,7 @@ for ROOT_SUBDIR in src test; do ...@@ -56,7 +58,7 @@ for ROOT_SUBDIR in src test; do
find "${ROOT_SUBDIR}" -type f -and \( -name '*.cpp' -or -name '*.hpp' \) | xargs "${CLANG_FORMAT_PROG}" -i -style=file find "${ROOT_SUBDIR}" -type f -and \( -name '*.cpp' -or -name '*.hpp' \) | xargs "${CLANG_FORMAT_PROG}" -i -style=file
bash_lib_status "Done." bash_lib_status "Done."
fi
done done
popd popd
...@@ -45,12 +45,14 @@ declare NUM_FILES_CHECKED=0 ...@@ -45,12 +45,14 @@ declare NUM_FILES_CHECKED=0
pushd "${THIS_SCRIPT_DIR}/.." pushd "${THIS_SCRIPT_DIR}/.."
declare ARGON_SRC_DIR="build/third-party/argon_transformer/src/ext_argon_transformer/src"
declare ARGON_TEST_DIR="build/third-party/argon_transformer/src/ext_argon_transformer/test"
declare ROOT_SUBDIR declare ROOT_SUBDIR
for ROOT_SUBDIR in src test; do for ROOT_SUBDIR in src test ${ARGON_SRC_DIR} ${ARGON_TEST_DIR}; do
if ! [[ -d "${ROOT_SUBDIR}" ]]; then if ! [[ -d "${ROOT_SUBDIR}" ]]; then
bash_lib_die "In directory '$(pwd)', no subdirectory named '${ROOT_SUBDIR}' was found." bash_lib_status "In directory '$(pwd)', no subdirectory named '${ROOT_SUBDIR}' was found."
fi else
bash_lib_status "About to format C/C++ code in directory tree '$(pwd)/${ROOT_SUBDIR}' ..." bash_lib_status "About to format C/C++ code in directory tree '$(pwd)/${ROOT_SUBDIR}' ..."
declare SRC_FILE declare SRC_FILE
# Note that we restrict to "-type f" to exclude symlinks. Emacs sometimes # Note that we restrict to "-type f" to exclude symlinks. Emacs sometimes
...@@ -62,6 +64,7 @@ for ROOT_SUBDIR in src test; do ...@@ -62,6 +64,7 @@ for ROOT_SUBDIR in src test; do
fi fi
NUM_FILES_CHECKED=$((NUM_FILES_CHECKED+1)) NUM_FILES_CHECKED=$((NUM_FILES_CHECKED+1))
done done
fi
done done
popd popd
...@@ -76,4 +79,3 @@ else ...@@ -76,4 +79,3 @@ else
done done
exit 1 exit 1
fi fi
...@@ -84,6 +84,7 @@ set (SRC ...@@ -84,6 +84,7 @@ set (SRC
ops/util/requires_tensor_view_args.cpp ops/util/requires_tensor_view_args.cpp
ops/util/unary_elementwise_arithmetic.cpp ops/util/unary_elementwise_arithmetic.cpp
ops/util/unary_elementwise.cpp ops/util/unary_elementwise.cpp
pass/assign_placement.cpp
pass/dump_sorted.cpp pass/dump_sorted.cpp
pass/graph_rewrite.cpp pass/graph_rewrite.cpp
pass/inliner.cpp pass/inliner.cpp
...@@ -112,6 +113,7 @@ set (SRC ...@@ -112,6 +113,7 @@ set (SRC
types/type.cpp types/type.cpp
util.cpp util.cpp
graph_util.cpp graph_util.cpp
placement.cpp
) )
message(STATUS ${CMAKE_CURRENT_SOURCE_DIR}/ops) message(STATUS ${CMAKE_CURRENT_SOURCE_DIR}/ops)
......
This diff is collapsed.
...@@ -30,11 +30,24 @@ ...@@ -30,11 +30,24 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "ngraph/placement.hpp"
namespace ngraph namespace ngraph
{ {
class Node; class Node;
class Function; class Function;
namespace descriptor
{
class Input;
class Output;
}
namespace op
{
class Parameter;
}
void traverse_nodes(const std::shared_ptr<const Function> p, void traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f); std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f); void traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f);
...@@ -60,7 +73,7 @@ namespace ngraph ...@@ -60,7 +73,7 @@ namespace ngraph
class NodeMap class NodeMap
{ {
public: public:
// map original node to replcacement node // map original node to replacement node
// throws ngraph_error if key already exists // throws ngraph_error if key already exists
void add(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> replacement); void add(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> replacement);
...@@ -100,4 +113,18 @@ namespace ngraph ...@@ -100,4 +113,18 @@ namespace ngraph
// NodeMap output (by reference) fully maps input and cloned function ops // NodeMap output (by reference) fully maps input and cloned function ops
std::shared_ptr<ngraph::Function> clone_function(std::shared_ptr<ngraph::Function> func, std::shared_ptr<ngraph::Function> clone_function(std::shared_ptr<ngraph::Function> func,
NodeMap& node_map); NodeMap& node_map);
// Assert that nodes in the function is colocated and return that placement
Placement get_colocated_function_placement(std::shared_ptr<Function> func);
// Split function to function(s) with unique placement
std::vector<std::shared_ptr<Function>> split_function_by_placement(
std::shared_ptr<Function> f,
std::unordered_map<std::shared_ptr<op::Parameter>, std::shared_ptr<Node>>&
map_parameter_to_source_node);
// Insert parameter node between src_node and dst_node by splitting the graph
void insert_parameter_split_between(std::shared_ptr<Node> src_node,
std::shared_ptr<Node> dst_node,
std::shared_ptr<op::Parameter> p_node);
} }
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ngraph/descriptor/layout/tensor_view_layout.hpp" #include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/primary_tensor_view.hpp" #include "ngraph/descriptor/primary_tensor_view.hpp"
#include "ngraph/ops/parameter.hpp" #include "ngraph/ops/parameter.hpp"
#include "ngraph/placement.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -144,6 +145,16 @@ void Node::set_name(const string& name) ...@@ -144,6 +145,16 @@ void Node::set_name(const string& name)
} }
} }
Placement Node::get_placement() const
{
return m_placement;
}
void Node::set_placement(Placement placement)
{
m_placement = placement;
}
std::shared_ptr<Node> Node::get_input_op(size_t index) std::shared_ptr<Node> Node::get_input_op(size_t index)
{ {
for (auto arg : m_arguments) for (auto arg : m_arguments)
...@@ -304,3 +315,27 @@ bool Node::has_same_type(std::shared_ptr<const Node> node) const ...@@ -304,3 +315,27 @@ bool Node::has_same_type(std::shared_ptr<const Node> node) const
} }
return true; return true;
} }
descriptor::Input* Node::get_input_from(const shared_ptr<Node>& src)
{
for (size_t i = 0; i < this->get_input_size(); ++i)
{
if (this->get_input_op(i) == src)
{
return &(this->get_inputs().at(i));
}
}
throw ngraph_error("Error: src is not one of self's input Node");
}
descriptor::Output* Node::get_output_to(const shared_ptr<Node>& dst)
{
for (size_t i = 0; i < dst->get_input_size(); ++i)
{
if (dst->get_input_op(i).get() == this)
{
return &(dst->get_inputs().at(i).get_output());
}
}
throw ngraph_error("Error: dst is not one of self's output Node");
}
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/placement.hpp"
#include "ngraph/types/type.hpp" #include "ngraph/types/type.hpp"
namespace ngraph namespace ngraph
...@@ -39,6 +40,10 @@ namespace ngraph ...@@ -39,6 +40,10 @@ namespace ngraph
void replace_node_users_arguments(std::shared_ptr<Node> target, void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement); std::shared_ptr<Node> replacement);
void insert_parameter_split_between(std::shared_ptr<Node> src_node,
std::shared_ptr<Node> dst_node,
std::shared_ptr<op::Parameter> p_node);
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor /// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
...@@ -49,6 +54,9 @@ namespace ngraph ...@@ -49,6 +54,9 @@ namespace ngraph
friend class descriptor::Input; friend class descriptor::Input;
friend void replace_node_users_arguments(std::shared_ptr<Node> target, friend void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement); std::shared_ptr<Node> replacement);
friend void insert_parameter_split_between(std::shared_ptr<Node> src_node,
std::shared_ptr<Node> dst_node,
std::shared_ptr<op::Parameter> p_node);
protected: protected:
Node(const std::string& node_type, const Nodes& arguments); Node(const std::string& node_type, const Nodes& arguments);
...@@ -165,9 +173,21 @@ namespace ngraph ...@@ -165,9 +173,21 @@ namespace ngraph
virtual std::vector<std::shared_ptr<Function>> get_functions() const; virtual std::vector<std::shared_ptr<Function>> get_functions() const;
// True if this and node have one output with same element type and shape /// True if this and node have one output with same element type and shape
bool has_same_type(std::shared_ptr<const Node> node) const; bool has_same_type(std::shared_ptr<const Node> node) const;
/// Get device placement
Placement get_placement() const;
/// Set device placement
void set_placement(Placement placement);
/// Get input descriptor that is connected to src
descriptor::Input* get_input_from(const std::shared_ptr<Node>& src);
/// Get ouput descriptor that outputs to dst
descriptor::Output* get_output_to(const std::shared_ptr<Node>& dst);
protected: protected:
void add_output(const element::Type& element_type, const Shape& shape); void add_output(const element::Type& element_type, const Shape& shape);
...@@ -180,9 +200,11 @@ namespace ngraph ...@@ -180,9 +200,11 @@ namespace ngraph
std::deque<descriptor::Output> m_outputs; std::deque<descriptor::Output> m_outputs;
bool m_is_output; bool m_is_output;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map; std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
Placement m_placement = Placement::DEFAULT;
private: private:
Nodes m_arguments; Nodes m_arguments;
//m_arguments still needs to be kept in sync with i/o since get_input_ops //m_arguments still needs to be kept in sync with i/o since get_input_ops
//is pretty ubiquitous and might be called after the original graph was modified. //is pretty ubiquitous and might be called after the original graph was modified.
//get_input_ops uses m_arguments to check if a node view reconstruction from i/o //get_input_ops uses m_arguments to check if a node view reconstruction from i/o
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/pass/assign_placement.hpp"
#include "ngraph/node.hpp"
#include "ngraph/placement.hpp"
using namespace std;
using namespace ngraph;
ngraph::pass::AssignPlacement::AssignPlacement(
std::function<Placement(std::shared_ptr<Node>)> placement_policy)
: m_placement_policy(placement_policy)
{
}
bool ngraph::pass::AssignPlacement::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
{
for (const std::shared_ptr<Node>& node : nodes)
{
run_on_node(node);
}
return false;
}
bool ngraph::pass::AssignPlacement::run_on_node(shared_ptr<Node> node)
{
node->set_placement(m_placement_policy(node));
return false;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <exception>
#include <sstream>
#include "ngraph/pass/pass.hpp"
#include "ngraph/placement.hpp"
namespace ngraph
{
namespace pass
{
class AssignPlacement : public CallGraphPass
{
public:
// TODO: make policy a class
AssignPlacement(std::function<Placement(std::shared_ptr<Node>)> placement_policy);
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
private:
bool run_on_node(std::shared_ptr<Node> node);
std::function<Placement(std::shared_ptr<Node>)> m_placement_policy;
};
}
}
...@@ -30,6 +30,12 @@ using namespace std; ...@@ -30,6 +30,12 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
ngraph::pass::Manager::Manager() ngraph::pass::Manager::Manager()
: m_to_set_is_output(true)
{
}
ngraph::pass::Manager::Manager(bool to_set_is_output)
: m_to_set_is_output(to_set_is_output)
{ {
} }
...@@ -50,6 +56,8 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func) ...@@ -50,6 +56,8 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
set<shared_ptr<Function>> tfs(begin(fs), end(fs)); set<shared_ptr<Function>> tfs(begin(fs), end(fs));
get_state().set_functions(tfs); get_state().set_functions(tfs);
if (m_to_set_is_output)
{
for (shared_ptr<Function> f : get_state().get_functions()) for (shared_ptr<Function> f : get_state().get_functions())
{ {
for (size_t i = 0; i < f->get_output_size(); ++i) for (size_t i = 0; i < f->get_output_size(); ++i)
...@@ -57,6 +65,7 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func) ...@@ -57,6 +65,7 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
f->get_output_op(i)->set_is_output(); f->get_output_op(i)->set_is_output();
} }
} }
}
for (shared_ptr<PassBase> pass : m_pass_list) for (shared_ptr<PassBase> pass : m_pass_list)
{ {
......
...@@ -36,6 +36,7 @@ class ngraph::pass::Manager ...@@ -36,6 +36,7 @@ class ngraph::pass::Manager
{ {
public: public:
Manager(); Manager();
Manager(bool to_set_is_output);
~Manager(); ~Manager();
void initialize_default_passes(); void initialize_default_passes();
...@@ -56,4 +57,5 @@ public: ...@@ -56,4 +57,5 @@ public:
private: private:
std::vector<std::shared_ptr<PassBase>> m_pass_list; std::vector<std::shared_ptr<PassBase>> m_pass_list;
ManagerState m_state; ManagerState m_state;
bool m_to_set_is_output;
}; };
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/placement.hpp"
std::string ngraph::placement_to_string(Placement placement)
{
switch (placement)
{
case Placement::DEFAULT: return "DEFAULT";
case Placement::INTERPRETER: return "INTERPRETER";
case Placement::CPU: return "CPU";
case Placement::GPU: return "GPU";
case Placement::ARGON: return "ARGON";
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <string>
namespace ngraph
{
enum class Placement
{
DEFAULT,
INTERPRETER,
CPU,
GPU,
ARGON,
};
std::string placement_to_string(Placement placement);
}
...@@ -796,6 +796,11 @@ static shared_ptr<ngraph::Function> ...@@ -796,6 +796,11 @@ static shared_ptr<ngraph::Function>
throw runtime_error(ss.str()); throw runtime_error(ss.str());
} }
node_map[node_name] = node; node_map[node_name] = node;
// Typically, it could be unsafe to change the name of a node since it may break nameing
// uniqueness. However, it could sometimes be helpful to use the original name from
// the serialization for debugging.
// node->set_name(node_name);
} }
std::vector<std::shared_ptr<Node>> result; std::vector<std::shared_ptr<Node>> result;
......
...@@ -39,6 +39,7 @@ set (SRC ...@@ -39,6 +39,7 @@ set (SRC
input_output_assign.cpp input_output_assign.cpp
main.cpp main.cpp
op.cpp op.cpp
graph_partition.cpp
pass_liveness.cpp pass_liveness.cpp
pass_manager.cpp pass_manager.cpp
pass_memory_layout.cpp pass_memory_layout.cpp
...@@ -192,4 +193,3 @@ add_custom_target(check ...@@ -192,4 +193,3 @@ add_custom_target(check
style-check style-check
unit-test-check unit-test-check
) )
...@@ -497,6 +497,8 @@ TEST(${BACKEND_NAME}, backwards_broadcast1) ...@@ -497,6 +497,8 @@ TEST(${BACKEND_NAME}, backwards_broadcast1)
TEST(${BACKEND_NAME}, backwards_concat_vector) TEST(${BACKEND_NAME}, backwards_concat_vector)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -522,6 +524,8 @@ TEST(${BACKEND_NAME}, backwards_concat_vector) ...@@ -522,6 +524,8 @@ TEST(${BACKEND_NAME}, backwards_concat_vector)
TEST(${BACKEND_NAME}, backwards_concat_axis_0) TEST(${BACKEND_NAME}, backwards_concat_axis_0)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -547,6 +551,8 @@ TEST(${BACKEND_NAME}, backwards_concat_axis_0) ...@@ -547,6 +551,8 @@ TEST(${BACKEND_NAME}, backwards_concat_axis_0)
TEST(${BACKEND_NAME}, backwards_concat_axis_1) TEST(${BACKEND_NAME}, backwards_concat_axis_1)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -572,6 +578,8 @@ TEST(${BACKEND_NAME}, backwards_concat_axis_1) ...@@ -572,6 +578,8 @@ TEST(${BACKEND_NAME}, backwards_concat_axis_1)
TEST(${BACKEND_NAME}, backwards_ceiling) TEST(${BACKEND_NAME}, backwards_ceiling)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -632,6 +640,8 @@ TEST(${BACKEND_NAME}, backwards_cos) ...@@ -632,6 +640,8 @@ TEST(${BACKEND_NAME}, backwards_cos)
TEST(${BACKEND_NAME}, backwards_cosh) TEST(${BACKEND_NAME}, backwards_cosh)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -654,6 +664,8 @@ TEST(${BACKEND_NAME}, backwards_cosh) ...@@ -654,6 +664,8 @@ TEST(${BACKEND_NAME}, backwards_cosh)
TEST(${BACKEND_NAME}, backwards_divide) TEST(${BACKEND_NAME}, backwards_divide)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -851,6 +863,8 @@ TEST(${BACKEND_NAME}, backwards_exp) ...@@ -851,6 +863,8 @@ TEST(${BACKEND_NAME}, backwards_exp)
TEST(${BACKEND_NAME}, backwards_floor) TEST(${BACKEND_NAME}, backwards_floor)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -1000,6 +1014,8 @@ TEST(${BACKEND_NAME}, backwards_parameter) ...@@ -1000,6 +1014,8 @@ TEST(${BACKEND_NAME}, backwards_parameter)
TEST(${BACKEND_NAME}, backwards_power) TEST(${BACKEND_NAME}, backwards_power)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -1068,6 +1084,8 @@ TEST(${BACKEND_NAME}, backwards_relu) ...@@ -1068,6 +1084,8 @@ TEST(${BACKEND_NAME}, backwards_relu)
TEST(${BACKEND_NAME}, backwards_replace_slice) TEST(${BACKEND_NAME}, backwards_replace_slice)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -1113,6 +1131,8 @@ TEST(${BACKEND_NAME}, backwards_reshape) ...@@ -1113,6 +1131,8 @@ TEST(${BACKEND_NAME}, backwards_reshape)
TEST(${BACKEND_NAME}, backwards_select) TEST(${BACKEND_NAME}, backwards_select)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -1147,6 +1167,8 @@ TEST(${BACKEND_NAME}, backwards_select) ...@@ -1147,6 +1167,8 @@ TEST(${BACKEND_NAME}, backwards_select)
TEST(${BACKEND_NAME}, backwards_select_nested) TEST(${BACKEND_NAME}, backwards_select_nested)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -1181,6 +1203,8 @@ TEST(${BACKEND_NAME}, backwards_select_nested) ...@@ -1181,6 +1203,8 @@ TEST(${BACKEND_NAME}, backwards_select_nested)
TEST(${BACKEND_NAME}, backwards_sign) TEST(${BACKEND_NAME}, backwards_sign)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -1235,6 +1259,8 @@ TEST(${BACKEND_NAME}, backwards_sin) ...@@ -1235,6 +1259,8 @@ TEST(${BACKEND_NAME}, backwards_sin)
TEST(${BACKEND_NAME}, backwards_sinh) TEST(${BACKEND_NAME}, backwards_sinh)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -1257,9 +1283,10 @@ TEST(${BACKEND_NAME}, backwards_sinh) ...@@ -1257,9 +1283,10 @@ TEST(${BACKEND_NAME}, backwards_sinh)
TEST(${BACKEND_NAME}, backwards_slice) TEST(${BACKEND_NAME}, backwards_slice)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
test::Uniform<float> rng(-10.0f, 10.0f); test::Uniform<float> rng(-10.0f, 10.0f);
Shape shape{5, 5}; Shape shape{5, 5};
auto make_graph = [shape]() { auto make_graph = [shape]() {
...@@ -1394,6 +1421,8 @@ TEST(${BACKEND_NAME}, backwards_sum_m2v_1) ...@@ -1394,6 +1421,8 @@ TEST(${BACKEND_NAME}, backwards_sum_m2v_1)
TEST(${BACKEND_NAME}, backwards_tan) TEST(${BACKEND_NAME}, backwards_tan)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
...@@ -1429,6 +1458,8 @@ TEST(${BACKEND_NAME}, backwards_tan) ...@@ -1429,6 +1458,8 @@ TEST(${BACKEND_NAME}, backwards_tan)
TEST(${BACKEND_NAME}, backwards_tanh) TEST(${BACKEND_NAME}, backwards_tanh)
{ {
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}"); SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("ARGON", "${BACKEND_NAME}");
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
......
This diff is collapsed.
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
#include <memory> #include <memory>
......
This diff is collapsed.
This diff is collapsed.
...@@ -213,11 +213,11 @@ def emit_test(t,f): ...@@ -213,11 +213,11 @@ def emit_test(t,f):
TEST (${BACKEND_NAME}, %s) TEST (${BACKEND_NAME}, %s)
{ {
Shape shape_a{%s}; Shape shape_a{%s};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{%s}; Shape shape_b{%s};
auto B = make_shared<op::Parameter>(element::f32, shape_b);
Shape shape_r{%s}; Shape shape_r{%s};
auto make_graph = [A, B] { auto make_graph = [shape_a, shape_b] {
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto B = make_shared<op::Parameter>(element::f32, shape_b);
return make_shared<Function>(make_shared<op::Convolution>(A, B, return make_shared<Function>(make_shared<op::Convolution>(A, B,
Strides{%s}, // move_strides Strides{%s}, // move_strides
Strides{%s}, // filter_dilation Strides{%s}, // filter_dilation
......
...@@ -69,7 +69,9 @@ namespace ngraph ...@@ -69,7 +69,9 @@ namespace ngraph
} }
if (a->get_shape() != b->get_shape()) if (a->get_shape() != b->get_shape())
{
return false; return false;
}
return all_close(read_vector<T>(a), read_vector<T>(b), rtol, atol); return all_close(read_vector<T>(a), read_vector<T>(b), rtol, atol);
} }
......
...@@ -14,9 +14,13 @@ ...@@ -14,9 +14,13 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include "ngraph/log.hpp"
#include "ngraph/runtime/manager.hpp"
#include "ngraph/types/element_type.hpp"
#include "util/all_close.hpp" #include "util/all_close.hpp"
#include "util/autodiff/backprop_derivative.hpp" #include "util/autodiff/backprop_derivative.hpp"
#include "util/autodiff/numeric_derivative.hpp" #include "util/autodiff/numeric_derivative.hpp"
#include "util/test_tools.hpp"
template <typename T> template <typename T>
bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Manager>& manager, bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Manager>& manager,
...@@ -27,15 +31,48 @@ bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Manager>& m ...@@ -27,15 +31,48 @@ bool autodiff_numeric_compare(const std::shared_ptr<ngraph::runtime::Manager>& m
T atol) T atol)
{ {
T delta = static_cast<T>(0.001); T delta = static_cast<T>(0.001);
// Use INTERPRETER to compute numerical derivatives
auto interpreter_manager = ngraph::runtime::Manager::get("INTERPRETER");
auto interpreter_backend = interpreter_manager->allocate_backend();
auto f = make_graph(); auto f = make_graph();
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_args;
for (auto arg : args)
{
auto interpreter_arg = interpreter_backend->make_primary_tensor_view(
arg->get_tensor().get_element_type(), arg->get_shape());
// TODO: copy_data should not require T. Quick fix here for bool used in `Select`
if (arg->get_tensor().get_element_type() == ngraph::element::boolean)
{
copy_data(interpreter_arg, read_vector<char>(arg));
}
else
{
copy_data(interpreter_arg, read_vector<T>(arg));
}
interpreter_args.push_back(interpreter_arg);
}
auto results_num = ngraph::autodiff::numeric_derivative<T>( auto results_num = ngraph::autodiff::numeric_derivative<T>(
manager, backend, f, args, delta, f->get_parameters()); interpreter_manager, interpreter_backend, f, interpreter_args, delta, f->get_parameters());
// Use the backend being tested to compute symbolic derivatives
auto g = make_graph(); auto g = make_graph();
auto results_sym = auto results_sym =
ngraph::autodiff::backprop_derivative<T>(manager, backend, g, args, g->get_parameters()); ngraph::autodiff::backprop_derivative<T>(manager, backend, g, args, g->get_parameters());
return ngraph::test::all_close(results_num, results_sym, rtol, atol); // Cast to HostTensorView for comparision
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_results_sym;
for (auto result : results_sym)
{
auto interpreter_result = interpreter_backend->make_primary_tensor_view(
ngraph::element::from<T>(), result->get_shape());
copy_data(interpreter_result, read_vector<T>(result));
interpreter_results_sym.push_back(interpreter_result);
}
return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol);
} }
template <typename T> template <typename T>
...@@ -48,6 +85,7 @@ bool autodiff_numeric_compare_selective( ...@@ -48,6 +85,7 @@ bool autodiff_numeric_compare_selective(
T atol, T atol,
const std::vector<bool>& indep_param_mask) const std::vector<bool>& indep_param_mask)
{ {
// Use INTERPRETER to compute numerical derivatives
std::vector<std::shared_ptr<ngraph::op::Parameter>> f_indep_params; std::vector<std::shared_ptr<ngraph::op::Parameter>> f_indep_params;
auto f = make_graph(); auto f = make_graph();
...@@ -62,9 +100,30 @@ bool autodiff_numeric_compare_selective( ...@@ -62,9 +100,30 @@ bool autodiff_numeric_compare_selective(
i++; i++;
} }
auto results_num = auto interpreter_manager = ngraph::runtime::Manager::get("INTERPRETER");
ngraph::autodiff::numeric_derivative<T>(manager, backend, f, args, .001f, f_indep_params); auto interpreter_backend = interpreter_manager->allocate_backend();
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_args;
for (auto arg : args)
{
auto interpreter_arg = interpreter_backend->make_primary_tensor_view(
arg->get_tensor().get_element_type(), arg->get_shape());
// TODO: copy_data should not require T. Quick fix here for bool used in `Select`
if (arg->get_tensor().get_element_type() == ngraph::element::boolean)
{
copy_data(interpreter_arg, read_vector<char>(arg));
}
else
{
copy_data(interpreter_arg, read_vector<T>(arg));
}
interpreter_args.push_back(interpreter_arg);
}
auto results_num = ngraph::autodiff::numeric_derivative<T>(
interpreter_manager, interpreter_backend, f, interpreter_args, .001f, f_indep_params);
// Use the backend being tested to compute symbolic derivatives
std::vector<std::shared_ptr<ngraph::op::Parameter>> g_indep_params; std::vector<std::shared_ptr<ngraph::op::Parameter>> g_indep_params;
auto g = make_graph(); auto g = make_graph();
...@@ -82,5 +141,15 @@ bool autodiff_numeric_compare_selective( ...@@ -82,5 +141,15 @@ bool autodiff_numeric_compare_selective(
auto results_sym = auto results_sym =
ngraph::autodiff::backprop_derivative<T>(manager, backend, g, args, g_indep_params); ngraph::autodiff::backprop_derivative<T>(manager, backend, g, args, g_indep_params);
return ngraph::test::all_close(results_num, results_sym, rtol, atol); // Cast to HostTensorView for comparision
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> interpreter_results_sym;
for (auto result : results_sym)
{
auto interpreter_result = interpreter_backend->make_primary_tensor_view(
ngraph::element::from<T>(), result->get_shape());
copy_data(interpreter_result, read_vector<T>(result));
interpreter_results_sym.push_back(interpreter_result);
}
return ngraph::test::all_close(results_num, interpreter_results_sym, rtol, atol);
} }
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/descriptor/layout/tensor_view_layout.hpp" #include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/file_util.hpp" #include "ngraph/file_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/serializer.hpp" #include "ngraph/serializer.hpp"
...@@ -32,6 +33,16 @@ ...@@ -32,6 +33,16 @@
return; \ return; \
} }
#define ONLY_ENABLE_TEST_FOR(backend_to_enable, current_backend) \
if (backend_to_enable != current_backend) \
{ \
return; \
} \
else \
{ \
NGRAPH_INFO << "Enabled test for " << current_backend; \
}
namespace ngraph namespace ngraph
{ {
class Node; class Node;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment