Commit 5cc7bd9a authored by Scott Cyphers

Merge branch 'cyphers/autodiff' of https://github.com/NervanaSystems/private-ngraph-cpp into cyphers/autodiff

parents 5bd169b8 7d39a99b
@@ -12,6 +12,8 @@ TODO

## Steps

+_If you are developing ngraph on macOS (officially unsupported), please see the section "macOS Development Prerequisites" below._
+
`libngraph` is built in the customary manner for a CMake-based project:

1. Create a build directory outside of the source directory tree.

@@ -23,6 +25,17 @@ TODO

6. _(Optional, requires `doxygen`)_ Run `make doc`.
   * This will build API documentation in the directory `doc` inside the build directory.

+## macOS Development Prerequisites
+
+The repository includes two scripts (`maint/check-code-format.sh` and `maint/apply-code-format.sh`) that are used, respectively, to check adherence to the `libngraph` code formatting conventions and to reformat code automatically according to those conventions. These scripts require the command `clang-format-3.9` to be in your `PATH`. Run the following commands (you will need to adjust them if you are not using `bash`):
+
+```
+$ brew install llvm@3.9
+$ mkdir -p $HOME/bin
+$ ln -s /usr/local/opt/llvm@3.9/bin/clang-format $HOME/bin/clang-format-3.9
+$ echo 'export PATH=$HOME/bin:$PATH' >> $HOME/.bash_profile
+```
+
# Testing `libngraph`

`libngraph` uses the GTest framework for unit tests. CMake automatically downloads a
...
@@ -39,26 +39,26 @@ Function::Function(const std::shared_ptr<Node>& result,
        parameter->assign_function(this, i++);
    }

-   traverse_nodes(result, [&](Node* node) { m_ops.push_back(node); });
+   traverse_nodes(result, [&](shared_ptr<Node> node) { m_ops.push_back(node); });
}

-void Function::set_ordered_ops(const std::list<Node*>& ordered_ops)
+void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops)
{
    m_ordered_ops = ordered_ops;
    m_ordered_ops_valid = true;
}

-std::list<Node*>& Function::get_ops()
+std::list<shared_ptr<Node>>& Function::get_ops()
{
    return m_ops;
}

-const std::list<Node*>& Function::get_ops() const
+const std::list<shared_ptr<Node>>& Function::get_ops() const
{
    return m_ops;
}

-std::list<Node*>& Function::get_ordered_ops()
+std::list<shared_ptr<Node>>& Function::get_ordered_ops()
{
    if (!m_ordered_ops_valid)
    {
@@ -67,7 +67,7 @@ std::list<Node*>& Function::get_ordered_ops()
    return m_ordered_ops;
}

-const std::list<Node*>& Function::get_ordered_ops() const
+const std::list<shared_ptr<Node>>& Function::get_ordered_ops() const
{
    if (!m_ordered_ops_valid)
    {
...
@@ -47,11 +47,11 @@ namespace ngraph
        const std::shared_ptr<const ValueType> get_result_type() const { return m_result_type; }
        std::string get_name() const;
        void set_name(const std::string& name);
-       std::list<Node*>& get_ops();
-       const std::list<Node*>& get_ops() const;
-       std::list<Node*>& get_ordered_ops();
-       const std::list<Node*>& get_ordered_ops() const;
-       void set_ordered_ops(const std::list<Node*>&);
+       std::list<std::shared_ptr<Node>>& get_ops();
+       const std::list<std::shared_ptr<Node>>& get_ops() const;
+       std::list<std::shared_ptr<Node>>& get_ordered_ops();
+       const std::list<std::shared_ptr<Node>>& get_ordered_ops() const;
+       void set_ordered_ops(const std::list<std::shared_ptr<Node>>&);
        void set_ordered_ops_valid() { m_ordered_ops_valid = true; }
        void clear_ordered_ops_valid() { m_ordered_ops_valid = false; }
        friend std::ostream& operator<<(std::ostream&, const Function&);
@@ -62,8 +62,8 @@ namespace ngraph
        std::string m_name;
        std::shared_ptr<const ValueType> m_result_type;
        bool m_ordered_ops_valid;
-       std::list<Node*> m_ordered_ops;
-       std::list<Node*> m_ops;
+       std::list<std::shared_ptr<Node>> m_ordered_ops;
+       std::list<std::shared_ptr<Node>> m_ops;

    private:
        Function(const Function&) = delete;
...
@@ -16,10 +16,35 @@
using namespace ngraph::op;

-void ScalarConstantBase::propagate_types()
+void ConstantBase::propagate_types()
{
}

-void TensorConstantBase::propagate_types()
+template <typename ET>
+void check_value_strings(const std::vector<std::string>& value_strings)
{
+    auto result = ET::read(value_strings);
+}
+
+void Constant::propagate_types()
+{
+    // No actual type propagation is done here; however, we check the number of value strings and
+    // also call check_value_strings just to make sure the result will be parseable at compile
+    // time. (It will throw an exception if not.)
+    auto tvt = std::dynamic_pointer_cast<const TensorViewType>(get_value_type());
+    if (nullptr == tvt)
+    {
+        throw ngraph_error("Constant does not have tensor view type");
+    }
+    auto shape = tvt->get_shape();
+    if (ngraph::shape_size(shape) != m_value_strings.size())
+    {
+        throw ngraph_error("Constant does not have the expected number of literals");
+    }
+    auto& et = tvt->get_element_type();
+    FUNCTION_ON_ELEMENT_TYPE(
+        et, "Constant has unhandled element type", check_value_strings, m_value_strings);
}
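
To make the validation concrete, here is a minimal sketch (hedged: it assumes the `op::Constant` class, `ngraph_error`, and the header layout introduced by this commit) of a constant whose literal count does not match its shape, which `propagate_types` rejects:

```
// Sketch only: ngraph headers assumed to be included.
#include <memory>
#include <string>
#include <vector>

void constant_validation_sketch()
{
    // Shape{2, 2} holds four elements, but only three literals are supplied, so
    // propagate_types() throws "Constant does not have the expected number of literals".
    auto bad = std::make_shared<ngraph::op::Constant>(
        ngraph::element::Float32::element_type(),
        ngraph::Shape{2, 2},
        std::vector<std::string>{"1", "2", "3"});
    bad->propagate_types();
}
```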
@@ -24,11 +24,11 @@ namespace ngraph
{
    namespace op
    {
-       // Defines methods to all constant scalars
-       class ScalarConstantBase : public Node
+       // Defines methods to all constants
+       class ConstantBase : public Node
        {
        protected:
-           ScalarConstantBase(const std::shared_ptr<TensorViewType>& type)
+           ConstantBase(const std::shared_ptr<TensorViewType>& type)
                : Node({}, type)
            {
            }
@@ -36,10 +36,9 @@ namespace ngraph
            virtual void propagate_types() override;
        };

-       // Implement a constant scalar for each element type.
-       // The static make method takes a
+       // Implement a constant tensor for each element type.
        template <typename T>
-       class ScalarConstant : public ScalarConstantBase
+       class ParameterizedConstant : public ConstantBase
        {
        public:
            // The ngraph element type
@@ -47,13 +46,15 @@
            // The C++ type that holds the element type
            using type = typename T::type;

-           ScalarConstant(typename T::type value)
-               : ScalarConstantBase(std::make_shared<TensorViewType>(T::element_type(), Shape{}))
+           ParameterizedConstant(
+               const Shape& shape,
+               typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>>& value)
+               : ConstantBase(std::make_shared<TensorViewType>(T::element_type(), shape))
                , m_value(value)
            {
            }

-           virtual std::string description() const override { return "ScalarConstant"; }
+           virtual std::string description() const override { return "ParameterizedConstant"; }
            virtual std::string get_node_id() const override
            {
                std::stringstream ss;
@@ -61,48 +62,41 @@
                return ss.str();
            }

-           type get_value() const { return m_value; }
-       protected:
-           typename T::type m_value;
-       };
-
-       using Float32ScalarConstant = ScalarConstant<element::Float32>;
-       using Int8ScalarConstant = ScalarConstant<element::Int8>;
-       using Int32ScalarConstant = ScalarConstant<element::Int32>;
-       using Int64ScalarConstant = ScalarConstant<element::Int64>;
-       using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
-       using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
-       using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
-
-       // Defines methods to all constant tensors
-       class TensorConstantBase : public Node
-       {
-       protected:
-           TensorConstantBase(const std::shared_ptr<TensorViewType>& type)
-               : Node({}, type)
+           typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> get_value() const
            {
+               return m_value;
            }

-           virtual void propagate_types() override;
+       protected:
+           std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value;
        };

-       // Implement a constant tensor for each element type.
-       template <typename T>
-       class TensorConstant : public TensorConstantBase
+       using Float32Constant = ParameterizedConstant<element::Float32>;
+       using Int8Constant = ParameterizedConstant<element::Int8>;
+       using Int32Constant = ParameterizedConstant<element::Int32>;
+       using Int64Constant = ParameterizedConstant<element::Int64>;
+       using UInt8Constant = ParameterizedConstant<element::UInt8>;
+       using UInt32Constant = ParameterizedConstant<element::UInt32>;
+       using UInt64Constant = ParameterizedConstant<element::UInt64>;
+
+       class Constant : public ConstantBase
        {
        public:
-           // The ngraph element type
-           using element_type = T;
-           // The C++ type that holds the element type
-           using type = typename T::type;
-
-           TensorConstant(const Shape& shape)
-               : TensorConstantBase(std::make_shared<TensorViewType>(T::element_type(), shape))
-               , m_value(ngraph::runtime::make_tensor<T>(shape))
+           Constant(const element::Type& et,
+                    const Shape& shape,
+                    const std::vector<std::string>& value_strings)
+               : ConstantBase(std::make_shared<TensorViewType>(et, shape))
+               , m_value_strings(value_strings)
            {
            }

-           virtual std::string description() const override { return "TensorConstant"; }
+           Constant(const element::Type& et, const Shape& shape, const std::string& value_string)
+               : ConstantBase(std::make_shared<TensorViewType>(et, shape))
+               , m_value_strings(ngraph::shape_size(shape), value_string)
+           {
+           }
+
+           virtual std::string description() const override { return "Constant"; }
            virtual std::string get_node_id() const override
            {
                std::stringstream ss;
@@ -110,21 +104,11 @@
                return ss.str();
            }

-           typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> get_value() const
-           {
-               return m_value;
-           }
+           const std::vector<std::string>& get_value_strings() const { return m_value_strings; }
+           virtual void propagate_types() override;

        protected:
-           std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value;
+           const std::vector<std::string> m_value_strings;
        };
-
-       using Float32TensorConstant = TensorConstant<element::Float32>;
-       using Int8TensorConstant = TensorConstant<element::Int8>;
-       using Int32TensorConstant = TensorConstant<element::Int32>;
-       using Int64TensorConstant = TensorConstant<element::Int64>;
-       using UInt8TensorConstant = TensorConstant<element::UInt8>;
-       using UInt32TensorConstant = TensorConstant<element::UInt32>;
-       using UInt64TensorConstant = TensorConstant<element::UInt64>;
    }
}
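
A quick usage sketch of the two flavors above, mirroring what the updated tests later in this commit do (hedged: `make_tensor` is taken from `ngraph::runtime` as this diff uses it):

```
// Typed constant: the value is a ParameterizedTensorView filled ahead of time.
auto t = ngraph::runtime::make_tensor<ngraph::element::Float32>(ngraph::Shape{2});
(*t) = std::vector<float>{1.0f, 2.0f};
auto typed = std::make_shared<ngraph::op::Float32Constant>(ngraph::Shape{2}, t);

// Untyped constant: the element type is a runtime value and the literals stay
// as strings until check_value_strings/ET::read parses them.
auto untyped = std::make_shared<ngraph::op::Constant>(
    ngraph::element::Float32::element_type(), ngraph::Shape{2},
    std::vector<std::string>{"1", "2"});

// The single-string constructor replicates one literal across the whole shape.
auto zeros = std::make_shared<ngraph::op::Constant>(
    ngraph::element::Int32::element_type(), ngraph::Shape{2, 3}, "0");
```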
@@ -28,7 +28,7 @@ namespace ngraph
            /// @param function The function to be called
            /// @param args The function arguments
            ///
-           FunctionCall(const std::shared_ptr<Function>& function,
+           FunctionCall(std::shared_ptr<Function> function,
                         const std::vector<std::shared_ptr<Node>>& args)
                : Builtin(args)
                , m_function(function)
...
@@ -25,15 +25,15 @@
using namespace std;
using namespace ngraph;

-bool pass::AssignTensors::run_on_call_graph(list<Node*>& nodes)
+bool pass::AssignTensors::run_on_call_graph(list<std::shared_ptr<Node>>& nodes)
{
-   for (Node* node : nodes)
+   for (shared_ptr<Node> node : nodes)
    {
        try
        {
            // We need to set the node's is_output state prior to calling assign_tensors
            // so that the output state can be passed to the constructed tensors.
-           if (node == get_state().get_functions().at(0)->get_result().get())
+           if (node == get_state().get_functions().at(0)->get_result())
            {
                node->set_is_output();
            }
...
@@ -27,7 +27,7 @@ namespace ngraph
class ngraph::pass::AssignTensors : public CallGraphPass
{
public:
-   virtual bool run_on_call_graph(std::list<Node*>& nodes) override;
+   virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) override;

private:
};
@@ -24,22 +24,22 @@ using namespace std;
using namespace ngraph;
using namespace ngraph::pass;

-bool CollectFunctions::run_on_function(ngraph::Function* func)
+bool CollectFunctions::run_on_function(shared_ptr<ngraph::Function> func)
{
-   set<Function*> functions;
-   deque<Function*> stack;
+   set<shared_ptr<ngraph::Function>> functions;
+   deque<shared_ptr<ngraph::Function>> stack;
    stack.push_back(func);

    while (stack.empty() == false)
    {
-       Function* f = stack.front();
+       shared_ptr<ngraph::Function> f = stack.front();
        stack.pop_front();
        functions.insert(f);
-       traverse_nodes(f->get_result(), [&](Node* node) {
-           op::FunctionCall* fc = dynamic_cast<op::FunctionCall*>(node);
+       traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) {
+           shared_ptr<op::FunctionCall> fc = dynamic_pointer_cast<op::FunctionCall>(node);
            if (fc)
            {
-               stack.push_back(fc->get_function().get());
+               stack.push_back(fc->get_function());
            }
        });
    }
...
@@ -27,7 +27,7 @@ namespace ngraph
class ngraph::pass::CollectFunctions : public FunctionPass
{
public:
-   bool run_on_function(ngraph::Function*) override;
+   bool run_on_function(std::shared_ptr<ngraph::Function>) override;

private:
};
@@ -28,14 +28,14 @@ pass::DumpSorted::DumpSorted(const string& output_file)
{
}

-bool pass::DumpSorted::run_on_module(vector<Function*>& functions)
+bool pass::DumpSorted::run_on_module(vector<shared_ptr<ngraph::Function>>& functions)
{
    ofstream out{m_output_file};
    if (out)
    {
-       for (Function* f : functions)
+       for (shared_ptr<Function> f : functions)
        {
-           for (const Node* node : f->get_ordered_ops())
+           for (const shared_ptr<Node>& node : f->get_ordered_ops())
            {
                out << node->get_name() << "(";
                vector<string> inputs;
...
@@ -31,7 +31,7 @@ class ngraph::pass::DumpSorted : public ModulePass
public:
    DumpSorted(const std::string& output_file);

-   virtual bool run_on_module(std::vector<Function*>&) override;
+   virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;

private:
    const std::string m_output_file;
...
@@ -28,13 +28,13 @@ using namespace std;
using namespace ngraph;
using namespace ngraph::descriptor;

-bool pass::Liveness::run_on_call_graph(list<Node*>& ops)
+bool pass::Liveness::run_on_call_graph(list<shared_ptr<Node>>& ops)
{
    unordered_set<Tensor*> currently_live;

    for (auto it = ops.rbegin(); it != ops.rend(); it++)
    {
-       Node* node = *it;
+       shared_ptr<Node> node = *it;
        node->liveness_live_list.clear();
        node->liveness_new_list.clear();
        node->liveness_free_list.clear();
@@ -91,7 +91,7 @@ bool pass::Liveness::run_on_call_graph(list<Node*>& ops)
    // Add outputs to live_list and remove from free_list
    unordered_set<Tensor*> outputs;
    unordered_set<Tensor*> seen;
-   for (Node* node : ops)
+   for (shared_ptr<Node> node : ops)
    {
        for (Tensor* tensor : node->liveness_live_list)
        {
...
@@ -28,7 +28,7 @@ namespace ngraph
class ngraph::pass::Liveness : public CallGraphPass
{
public:
-   virtual bool run_on_call_graph(std::list<Node*>&) override;
+   virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override;

private:
    bool is_temporary(const descriptor::Tensor&);
...
@@ -38,12 +38,7 @@ void ngraph::pass::Manager::initialize_default_passes()
void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
{
-   run_passes(func.get());
-}
-
-void ngraph::pass::Manager::run_passes(Function* func)
-{
-   vector<Function*> fs = {func};
+   vector<shared_ptr<Function>> fs = {func};
    get_state().set_functions(fs);

    for (shared_ptr<PassBase> pass : m_pass_list)
@@ -59,16 +54,16 @@ void ngraph::pass::Manager::run_passes(Function* func)
        }
        else if (function_pass)
        {
-           for (Function* f : fs)
+           for (shared_ptr<Function> f : fs)
            {
                function_pass->run_on_function(f);
            }
        }
        else if (node_pass)
        {
-           for (Function* f : fs)
+           for (shared_ptr<Function> f : fs)
            {
-               for (Node* n : f->get_ops())
+               for (shared_ptr<Node> n : f->get_ops())
                {
                    node_pass->run_on_node(n);
                }
@@ -76,7 +71,7 @@ void ngraph::pass::Manager::run_passes(Function* func)
        }
        else if (call_graph_pass)
        {
-           for (Function* f : fs)
+           for (shared_ptr<Function> f : fs)
            {
                call_graph_pass->run_on_call_graph(f->get_ordered_ops());
            }
...
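With the raw-pointer overload gone, callers hold the `Function` in a `shared_ptr` for the whole pass run. A minimal driving sketch (hedged: the pass-registration call is not shown in this diff, which only exposes `m_pass_list.push_back(pass_base)` inside it, so the registration step is left as a comment):

```
// Sketch: run the pass pipeline over a function.
std::shared_ptr<ngraph::Function> f = /* ... build a graph ... */ nullptr;
ngraph::pass::Manager manager;
// ... register passes such as TopologicalSort, AssignTensors, Liveness ...
manager.run_passes(f); // f stays alive via shared_ptr for the duration of all passes
```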
@@ -47,7 +47,6 @@ public:
        m_pass_list.push_back(pass_base);
    }

-   void run_passes(Function*);
    void run_passes(std::shared_ptr<Function>);

    ManagerState& get_state();
...
@@ -23,7 +23,7 @@
using namespace std;
using namespace ngraph;

-vector<Function*>& ngraph::pass::ManagerState::get_functions()
+vector<shared_ptr<Function>>& ngraph::pass::ManagerState::get_functions()
{
    return m_function_list;
}
...
@@ -30,7 +30,7 @@ namespace ngraph
class ngraph::pass::ManagerState
{
public:
-   std::vector<Function*>& get_functions();
+   std::vector<std::shared_ptr<Function>>& get_functions();

    template <typename T>
    void set_functions(const T& collection)
@@ -44,5 +44,5 @@ public:
private:
    size_t m_temporary_pool_size = 0;
-   std::vector<Function*> m_function_list;
+   std::vector<std::shared_ptr<Function>> m_function_list;
};
@@ -26,10 +26,10 @@ using namespace std;
using namespace ngraph;
using namespace ngraph::descriptor;

-bool pass::MemoryLayout::run_on_call_graph(std::list<Node*>& node_list)
+bool pass::MemoryLayout::run_on_call_graph(std::list<std::shared_ptr<Node>>& node_list)
{
    MemoryManager mm;

-   for (const Node* node : node_list)
+   for (shared_ptr<Node> node : node_list)
    {
        for (Tensor* tensor : node->liveness_new_list)
        {
...
@@ -33,7 +33,7 @@ namespace ngraph
class ngraph::pass::MemoryLayout : public CallGraphPass
{
public:
-   virtual bool run_on_call_graph(std::list<Node*>&) override;
+   virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override;

private:
};
...
@@ -32,13 +32,13 @@ pass::MemoryVisualize::MemoryVisualize(const string& filename)
{
}

-bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
+bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<ngraph::Function>>& functions)
{
    ofstream file(m_filename);
    {
-       for (const Function* f : functions)
+       for (shared_ptr<Function> f : functions)
        {
-           const list<Node*> nodes = f->get_ordered_ops();
+           list<shared_ptr<Node>> nodes = f->get_ordered_ops();
            file << "<!DOCTYPE html>\n<html>\n";
            file << "<head>\n";
            file << " <style>\n";
@@ -62,7 +62,7 @@ bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
            file << "<body>\n";
            unordered_set<descriptor::Tensor*> tensors;
            size_t temp_max_size = 0;
-           for (Node* node : nodes)
+           for (shared_ptr<Node> node : nodes)
            {
                tensors.insert(node->liveness_live_list.begin(), node->liveness_live_list.end());
            }
@@ -96,11 +96,11 @@ bool pass::MemoryVisualize::run_on_module(vector<Function*>& functions)
    return false;
}

-const Node* pass::MemoryVisualize::find_largest_op(const list<Node*>& nodes)
+shared_ptr<Node> pass::MemoryVisualize::find_largest_op(const list<shared_ptr<Node>>& nodes)
{
-   const Node* largest_op = nullptr;
+   shared_ptr<Node> largest_op = nullptr;
    size_t largest_size = 0;
-   for (const Node* exop : nodes)
+   for (shared_ptr<Node> exop : nodes)
    {
        size_t size = 0;
        for (const Tensor* tensor : exop->liveness_live_list)
@@ -116,9 +116,9 @@ const Node* pass::MemoryVisualize::find_largest_op(const list<Node*>& nodes)
    return largest_op;
}

-void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>& nodes)
+void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_ptr<Node>>& nodes)
{
-   const Node* largest_op = find_largest_op(nodes);
+   shared_ptr<Node> largest_op = find_largest_op(nodes);

    if (largest_op)
    {
@@ -130,7 +130,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
        unordered_map<const Tensor*, size_t> age_list;
        vector<const Tensor*> tensor_set;
-       unordered_map<const Tensor*, const Node*> generator_op;
+       unordered_map<const Tensor*, shared_ptr<Node>> generator_op;
        file << "<table>\n";
        file << " <tr>";
        file << "<th align=\"left\">tensor</th>";
@@ -139,7 +139,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
        file << "<th align=\"right\">generator weight</th>";
        file << "</tr>\n";
        size_t i = 0;
-       for (const Node* exop : nodes)
+       for (shared_ptr<Node> exop : nodes)
        {
            for (const Tensor* tensor : exop->liveness_new_list)
            {
@@ -179,7 +179,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<Node*>&
    }
}

-void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nodes)
+void pass::MemoryVisualize::draw_histogram(ostream& file, const list<shared_ptr<Node>>& nodes)
{
    size_t stroke_width = 14;
    size_t text_offset = 4;
@@ -188,7 +188,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
    size_t scale = width - offset;
    size_t line_spacing = stroke_width * 1.5;
    size_t line_count = 0;
-   for (const Node* node : nodes)
+   for (shared_ptr<Node> node : nodes)
    {
        (void)node;
        line_count += 1;
@@ -198,7 +198,7 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
    file << "<svg viewBox=\"0 0 " << width << " " << height << "\">\n";
    size_t y = 0;
-   for (const Node* node : nodes)
+   for (shared_ptr<Node> node : nodes)
    {
        float usage = float(MemoryVisualize::memory_usage(node));
        float footprint = float(MemoryVisualize::memory_footprint(node));
@@ -220,14 +220,14 @@ void pass::MemoryVisualize::draw_histogram(ostream& file, const list<Node*>& nod
    file << "</svg>\n";
}

-void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>& nodes)
+void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<shared_ptr<Node>>& nodes)
{
    file << "<table>\n";
    file << " <tr>";
    file << "<th align=\"left\">op</th>";
    file << "<th align=\"right\">influence</th>";
    file << "</tr>\n";
-   for (const Node* exop : nodes)
+   for (shared_ptr<Node> exop : nodes)
    {
        int weight = compute_op_weight(exop);
        file << " <tr>";
@@ -237,7 +237,7 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<Node*>&
    }
}

-int pass::MemoryVisualize::compute_op_weight(const Node* exop)
+int pass::MemoryVisualize::compute_op_weight(const shared_ptr<Node> exop)
{
    int mass = 0;
    // for input_decl in exop.input_decls:
@@ -265,17 +265,17 @@ int pass::MemoryVisualize::compute_op_weight(const Node* exop)
    return mass;
}

-size_t pass::MemoryVisualize::memory_usage(const Node* node)
+size_t pass::MemoryVisualize::memory_usage(shared_ptr<Node> node)
{
    return 0;
}

-size_t pass::MemoryVisualize::memory_footprint(const Node* node)
+size_t pass::MemoryVisualize::memory_footprint(shared_ptr<Node> node)
{
    return 0;
}

-size_t pass::MemoryVisualize::memory_footprint(const std::list<Node*>& nodes)
+size_t pass::MemoryVisualize::memory_footprint(const std::list<shared_ptr<Node>>& nodes)
{
    return 0;
}
@@ -32,18 +32,18 @@ class ngraph::pass::MemoryVisualize : public ModulePass
{
public:
    MemoryVisualize(const std::string& filename);
-   virtual bool run_on_module(std::vector<Function*>&) override;
+   virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;

private:
-   const Node* find_largest_op(const std::list<Node*>& nodes);
-   void draw_tensor_weight(std::ostream& file, const std::list<Node*>& nodes);
-   void draw_histogram(std::ostream& file, const std::list<Node*>& nodes);
-   void draw_op_influence(std::ostream& file, const std::list<Node*>& nodes);
-   int compute_op_weight(const Node* exop);
+   std::shared_ptr<Node> find_largest_op(const std::list<std::shared_ptr<Node>>& nodes);
+   void draw_tensor_weight(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
+   void draw_histogram(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
+   void draw_op_influence(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
+   int compute_op_weight(std::shared_ptr<Node> exop);

-   static size_t memory_usage(const Node*);
-   static size_t memory_footprint(const Node*);
-   static size_t memory_footprint(const std::list<Node*>&);
+   static size_t memory_usage(std::shared_ptr<Node>);
+   static size_t memory_footprint(std::shared_ptr<Node>);
+   static size_t memory_footprint(const std::list<std::shared_ptr<Node>>&);

    const std::string m_filename;
};
@@ -53,26 +53,26 @@ class ngraph::pass::ModulePass : public PassBase
{
public:
    virtual ~ModulePass() {}
-   virtual bool run_on_module(std::vector<ngraph::Function*>&) = 0;
+   virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) = 0;
};

class ngraph::pass::FunctionPass : public PassBase
{
public:
    virtual ~FunctionPass() {}
-   virtual bool run_on_function(ngraph::Function*) = 0;
+   virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
};

class ngraph::pass::NodePass : public PassBase
{
public:
    virtual ~NodePass() {}
-   virtual bool run_on_node(ngraph::Node*) = 0;
+   virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0;
};

class ngraph::pass::CallGraphPass : public PassBase
{
public:
    virtual ~CallGraphPass() {}
-   virtual bool run_on_call_graph(std::list<ngraph::Node*>&) = 0;
+   virtual bool run_on_call_graph(std::list<std::shared_ptr<ngraph::Node>>&) = 0;
};
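
Under the new interfaces, a custom pass receives `shared_ptr`s directly. A minimal sketch of a `NodePass` written against the signature above (the pass itself is hypothetical):

```
// Hypothetical example pass: counts the nodes it visits.
class CountNodes : public ngraph::pass::NodePass
{
public:
    bool run_on_node(std::shared_ptr<ngraph::Node> node) override
    {
        m_count++;
        return false; // false: the graph was not modified
    }
    size_t m_count = 0;
};
```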
@@ -20,9 +20,9 @@
using namespace std;
using namespace ngraph;

-bool pass::PropagateTypes::run_on_call_graph(list<Node*>& nodes)
+bool pass::PropagateTypes::run_on_call_graph(list<shared_ptr<Node>>& nodes)
{
-   for (Node* node : nodes)
+   for (shared_ptr<Node> node : nodes)
    {
        try
        {
...
@@ -27,7 +27,7 @@ namespace ngraph
class ngraph::pass::PropagateTypes : public CallGraphPass
{
public:
-   virtual bool run_on_call_graph(std::list<Node*>&) override;
+   virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override;

private:
};
@@ -25,24 +25,26 @@
using namespace ngraph;
using namespace std;

-bool ngraph::pass::TopologicalSort::run_on_function(ngraph::Function* func)
+bool ngraph::pass::TopologicalSort::run_on_function(shared_ptr<ngraph::Function> func)
{
-   list<Node*> result_list;
+   list<shared_ptr<Node>> result_list;
    deque<Node*> independent_nodes;
-   unordered_map<Node*, size_t> node_depencency_count;
+   unordered_map<const Node*, size_t> node_depencency_count;
+   unordered_map<Node*, shared_ptr<Node>> node_map;

-   traverse_nodes(func->get_result(), [&](Node* node) {
-       node_depencency_count[node] = node->get_arguments().size();
+   traverse_nodes(func->get_result(), [&](shared_ptr<Node> node) {
+       node_map[node.get()] = node;
+       node_depencency_count[node.get()] = node->get_arguments().size();
        if (node->get_arguments().size() == 0)
        {
-           independent_nodes.push_back(node);
+           independent_nodes.push_back(node.get());
        }
    });

    while (independent_nodes.size() > 0)
    {
        auto independent_node = independent_nodes.front();
-       result_list.push_back(independent_node);
+       result_list.push_back(node_map[independent_node]);
        independent_nodes.pop_front();

        for (auto user : independent_node->users())
...
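This pass is Kahn's algorithm: nodes with no unsatisfied inputs are scheduled, and each scheduled node decrements its users' dependency counts; `node_map` exists only to recover `shared_ptr`s for the raw `Node*` keys used in the work queue. For reference, a self-contained sketch of the same scheme outside ngraph (ids stand in for nodes; every id is assumed to appear in both maps):

```
#include <deque>
#include <list>
#include <unordered_map>
#include <vector>

// Kahn's algorithm: `deps` maps an id to the ids it depends on,
// `users` maps an id to the ids that consume it.
std::list<int> topo_sort(const std::unordered_map<int, std::vector<int>>& deps,
                         const std::unordered_map<int, std::vector<int>>& users)
{
    std::list<int> order;
    std::deque<int> ready;
    std::unordered_map<int, size_t> pending;
    for (const auto& kv : deps)
    {
        pending[kv.first] = kv.second.size();
        if (kv.second.empty())
        {
            ready.push_back(kv.first); // no inputs: immediately schedulable
        }
    }
    while (!ready.empty())
    {
        int n = ready.front();
        ready.pop_front();
        order.push_back(n);
        for (int u : users.at(n))
        {
            if (--pending[u] == 0)
            {
                ready.push_back(u); // last input satisfied
            }
        }
    }
    return order;
}
```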
@@ -31,5 +31,5 @@ class ngraph::pass::TopologicalSort : public FunctionPass
{
public:
    TopologicalSort() {}
-   bool run_on_function(ngraph::Function*) override;
+   bool run_on_function(std::shared_ptr<ngraph::Function>) override;
};
@@ -23,15 +23,15 @@
using namespace ngraph;
using namespace std;

-bool pass::VisualizeTree::run_on_module(vector<ngraph::Function*>& functions)
+bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& functions)
{
-   for (Function* f : functions)
+   for (shared_ptr<Function> f : functions)
    {
        // map<size_t, list<node_ptr>> dependent_nodes;
-       traverse_nodes(f->get_result(), [&](Node* node) {
+       traverse_nodes(f->get_result(), [&](shared_ptr<Node> node) {
            for (auto arg : node->get_arguments())
            {
-               m_ss << add_attributes(arg.get());
+               m_ss << add_attributes(arg);
                m_ss << add_attributes(node);
                m_ss << " " << arg->get_name() << " -> " << node->get_name();
                m_ss << ";\n";
@@ -49,7 +49,7 @@ pass::VisualizeTree::VisualizeTree(const string& file_name)
{
}

-std::string pass::VisualizeTree::add_attributes(const Node* node)
+std::string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
{
    string rc;
    if (!contains(m_nodes_with_attributes, node))
@@ -60,7 +60,7 @@ std::string pass::VisualizeTree::add_attributes(const Node* node)
    return rc;
}

-std::string pass::VisualizeTree::get_attributes(const Node* node)
+std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
{
    stringstream ss;
    if (node->is_parameter())
...
@@ -32,14 +32,14 @@ class ngraph::pass::VisualizeTree : public ModulePass
{
public:
    VisualizeTree(const std::string& file_name);
-   bool run_on_module(std::vector<ngraph::Function*>&) override;
+   bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;

private:
-   std::string add_attributes(const Node* node);
-   std::string get_attributes(const Node* node);
+   std::string add_attributes(std::shared_ptr<Node> node);
+   std::string get_attributes(std::shared_ptr<Node> node);
    void render() const;

    std::stringstream m_ss;
    std::string m_name;
-   std::set<const Node*> m_nodes_with_attributes;
+   std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
};
@@ -315,15 +315,10 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
#define REGISTER_CONSTANT_INSTRUCTIONS(T)                                                \
    {                                                                                    \
        REGISTER_INSTRUCTION(                                                            \
-           op::ScalarConstant<T>,                                                       \
-           eigen::ConstantInstruction<T>,                                               \
-           std::vector<T::type>{dynamic_cast<const op::ScalarConstant<T>*>(n)->get_value()}, \
-           out[0]);                                                                     \
-       REGISTER_INSTRUCTION(                                                            \
-           op::TensorConstant<T>,                                                       \
+           op::ParameterizedConstant<T>,                                                \
            eigen::ConstantInstruction<T>,                                               \
            std::vector<T::type>{                                                       \
-               dynamic_cast<const op::TensorConstant<T>*>(n)->get_value()->get_vector()}, \
+               dynamic_cast<const op::ParameterizedConstant<T>*>(n)->get_value()->get_vector()}, \
            out[0]);                                                                    \
    }
@@ -371,6 +366,23 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
        REGISTER_NUMERIC_BINOP(op::Multiply, eigen::MultiplyInstruction);
        REGISTER_NUMERIC_BINOP(op::Subtract, eigen::SubtractInstruction);

+       REGISTER_TO_OP_MAP(op::Constant)
+       {
+           auto c = static_cast<const op::Constant*>(n);
+           auto c_tensor_type = dynamic_pointer_cast<const TensorViewType>(c->get_value_type());
+           assert(nullptr != c_tensor_type);
+           auto& c_element_type = c_tensor_type->get_element_type();
+           auto c_value_strings = c->get_value_strings();
+
+#define M_REGISTER_POLYMORPHIC_CONSTANT(ET)                                              \
+    ef->get_instructions()->push_back(                                                   \
+        make_shared<eigen::ConstantInstruction<ET>>(ET::read(c_value_strings), out[0]));
+
+           DO_ON_ELEMENT_TYPE(c_element_type,
+                              "Constant has unhandled element type",
+                              M_REGISTER_POLYMORPHIC_CONSTANT);
+       };
+
        REGISTER_POLYMORPHIC_BINOP(op::Equal, eigen::EqualInstruction);
        REGISTER_POLYMORPHIC_BINOP(op::NotEqual, eigen::NotEqualInstruction);
@@ -949,7 +961,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
    // Turn this into a pass
    // Assign layouts
    // For now, just make everyone row-major.
-   for (const Node* node : m_function->get_ordered_ops())
+   for (shared_ptr<Node> node : m_function->get_ordered_ops())
    {
        for (const descriptor::Output& output : node->get_outputs())
        {
@@ -986,7 +998,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
    m_n_outputs = tensor_index.size() - m_n_inputs;

    // All remaining tensor views
-   for (const Node* node : m_function->get_ordered_ops())
+   for (shared_ptr<Node> node : m_function->get_ordered_ops())
    {
        for (const descriptor::Output& output : node->get_outputs())
        {
@@ -1002,9 +1014,11 @@ void ExternalFunction::compile(FunctionMap& function_map)
    // Now we build the eigen-VM instructions
    auto op_map = get_op_map();

-   for (const Node* node : m_function->get_ordered_ops())
+   for (shared_ptr<Node> node : m_function->get_ordered_ops())
    {
-       auto handler_it = op_map.find(type_index(typeid(*node)));
+       auto& n = *node; // Work around a compiler warning (*node inside typeid may have effects
+                        // with shared pointers, which is fine here but clang doesn't like it.)
+       auto handler_it = op_map.find(type_index(typeid(n)));
        if (handler_it == op_map.end())
        {
            throw ngraph_error("Unhandled op during code generation");
@@ -1022,7 +1036,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
            auto tv = output.get_tensor_view();
            out.push_back({tensor_index.at(tv), tv});
        }
-       handler_it->second(node, this, function_map, in, out);
+       handler_it->second(node.get(), this, function_map, in, out);
    }
    m_instructions->push_back(make_shared<eigen::ReturnInstruction>());
    m_is_compiled = true;
...
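The compile-and-run flow these changes feed into appears in the tests at the end of this commit; as a hedged summary sketch (the "NGVM" backend and call-frame names are taken from those tests, `f` and `shape` are placeholders):

```
// Sketch: compile a Function and execute it on the NGVM backend.
auto manager = ngraph::runtime::Manager::get("NGVM");
auto external = manager->compile(f); // f: std::shared_ptr<ngraph::Function>
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
auto result = ngraph::runtime::make_tensor<ngraph::element::Float32>(shape);
(*cf)({}, {result}); // no inputs for a constant-only function
```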
@@ -116,6 +116,35 @@ namespace ngraph
            {
                return std::make_shared<runtime::ParameterizedTensorView<TraitedType<T>>>(shape);
            }

+           static T read(const std::string& s)
+           {
+               T result;
+               std::stringstream ss;
+
+               ss << s;
+               ss >> result;
+
+               // Check that (1) parsing succeeded and (2) the entire string was used.
+               if (ss.fail() || ss.rdbuf()->in_avail() != 0)
+               {
+                   throw ngraph_error("Could not parse literal");
+               }
+
+               return result;
+           }
+
+           static std::vector<T> read(const std::vector<std::string>& ss)
+           {
+               std::vector<T> result;
+
+               for (auto s : ss)
+               {
+                   result.push_back(read(s));
+               }
+
+               return result;
+           }
        };

        NGRAPH_DEFINE_TRAITED_TYPE_NAME(char)
@@ -143,3 +172,23 @@ namespace ngraph
        using UInt64 = TraitedType<uint64_t>;
    }
}

+//
+// Utility macro for dispatching an element type-templated function at runtime.
+//
+
+// clang-format off
+// Sorry, but you really don't want to see what clang-format does to this thing. :)
+#define FUNCTION_ON_ELEMENT_TYPE(et, err_msg, f, ...)                                     \
+    (                                                                                     \
+        ((et) == element::Bool::element_type())    ? (f<element::Bool>(__VA_ARGS__))    : \
+        ((et) == element::Float32::element_type()) ? (f<element::Float32>(__VA_ARGS__)) : \
+        ((et) == element::Int8::element_type())    ? (f<element::Int8>(__VA_ARGS__))    : \
+        ((et) == element::Int32::element_type())   ? (f<element::Int32>(__VA_ARGS__))   : \
+        ((et) == element::Int64::element_type())   ? (f<element::Int64>(__VA_ARGS__))   : \
+        ((et) == element::UInt8::element_type())   ? (f<element::UInt8>(__VA_ARGS__))   : \
+        ((et) == element::UInt32::element_type())  ? (f<element::UInt32>(__VA_ARGS__))  : \
+        ((et) == element::UInt64::element_type())  ? (f<element::UInt64>(__VA_ARGS__))  : \
+        (throw ngraph_error(err_msg))                                                    \
+    )
+// clang-format on
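
A short sketch of how `read` and the dispatch macro combine, consistent with `check_value_strings` in constant.cpp above:

```
// Parse single and multiple literals; read() throws on malformed or trailing input.
float x = ngraph::element::Float32::read("4.8");
std::vector<float> v =
    ngraph::element::Float32::read(std::vector<std::string>{"1", "-5.3"});
// ngraph::element::Float32::read("4.8abc") would throw "Could not parse literal".

// Dispatch a templated function over a runtime element type, e.g.:
// FUNCTION_ON_ELEMENT_TYPE(et, "unhandled element type", check_value_strings, strings);
```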
@@ -137,15 +137,16 @@ size_t ngraph::hash_combine(const std::vector<size_t>& list)
    return seed;
}

-void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, std::function<void(Node*)> f)
+void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p,
+                            std::function<void(shared_ptr<Node>)> f)
{
-   std::unordered_set<Node*> instances_seen;
-   deque<Node*> stack;
-   stack.push_front(p.get());
+   std::unordered_set<shared_ptr<Node>> instances_seen;
+   deque<shared_ptr<Node>> stack;
+   stack.push_front(p);

    while (stack.size() > 0)
    {
-       Node* n = stack.front();
+       shared_ptr<Node> n = stack.front();
        if (instances_seen.find(n) == instances_seen.end())
        {
            instances_seen.insert(n);
@@ -154,7 +155,7 @@ void ngraph::traverse_nodes(const std::shared_ptr<ngraph::Node>& p, std::functio
        stack.pop_front();
        for (auto arg : n->get_arguments())
        {
-           stack.push_front(arg.get());
+           stack.push_front(arg);
        }
    }
}
@@ -163,7 +164,7 @@ void ngraph::free_nodes(shared_ptr<Node> p)
{
    std::deque<Node*> sorted_list;

-   traverse_nodes(p, [&](Node* n) { sorted_list.push_front(n); });
+   traverse_nodes(p, [&](shared_ptr<Node> n) { sorted_list.push_front(n.get()); });

    for (Node* n : sorted_list)
    {
...
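Callers of `traverse_nodes` now receive a `shared_ptr` per visited node, so they can retain nodes without touching raw pointers; a brief usage sketch (`f` is a placeholder for any `std::shared_ptr<ngraph::Function>`):

```
// Sketch: collect every node reachable from a function's result.
std::vector<std::shared_ptr<ngraph::Node>> all_nodes;
ngraph::traverse_nodes(f->get_result(),
                       [&](std::shared_ptr<ngraph::Node> node) { all_nodes.push_back(node); });
```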
@@ -195,7 +195,8 @@ namespace ngraph
        return a * b;
    }

-   void traverse_nodes(const std::shared_ptr<Node>& p, std::function<void(Node*)> f);
+   void traverse_nodes(const std::shared_ptr<Node>& p,
+                       std::function<void(std::shared_ptr<Node>)> f);
    void free_nodes(std::shared_ptr<Node>);
} // end namespace ngraph
@@ -80,23 +80,22 @@ TEST(build_graph, node_comparison)
TEST(build_graph, literal)
{
    // float scalar from a float
-   //auto float0 = FloatScalarConstant::make(3.0);
-   auto float0 = make_shared<op::Float32ScalarConstant>(3.0);
+   //auto float0 = FloatConstant::make(3.0);
+   auto float_t = ngraph::runtime::make_tensor<element::Float32>(Shape{});
+   (*float_t) = std::vector<float>{3.0};
+   auto float0 = make_shared<op::Float32Constant>(Shape{}, float_t);
    auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
-   ASSERT_EQ(float0->get_value(), 3.0);
+   ASSERT_EQ(float0->get_value()->get_vector(), std::vector<float>{3.0});
    ASSERT_EQ(*float0->get_value_type(), *float_scalar_type);
    auto d = make_shared<op::Dot>(float0, float0);
    ASSERT_EQ(d->get_arguments().at(0), float0);
    ASSERT_EQ(d->get_arguments().at(1), float0);

-   // float scalar from an int
-   auto float1 = make_shared<op::Float32ScalarConstant>(3);
-   ASSERT_EQ(float1->get_value(), 3);
-   ASSERT_EQ(*float1->get_value_type(), *float_scalar_type);
-
-   auto int32_0 = make_shared<op::Int32ScalarConstant>(3.0);
+   auto int32_t = ngraph::runtime::make_tensor<element::Int32>(Shape{});
+   (*int32_t) = std::vector<int>{3};
+   auto int32_0 = make_shared<op::Int32Constant>(Shape{}, int32_t);
    auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
-   ASSERT_EQ(int32_0->get_value(), 3);
+   ASSERT_EQ(int32_0->get_value()->get_vector(), std::vector<int>{3});
    ASSERT_EQ(*int32_0->get_value_type(), *int32_scalar_type);
    ASSERT_NE(*int32_0->get_value_type(), *float_scalar_type);
}
@@ -104,8 +103,9 @@ TEST(build_graph, literal)
TEST(build_graph, tensor)
{
    // float scalar from a float
-   //auto float0 = FloatScalarConstant::make(3.0);
-   auto float0 = make_shared<op::Float32TensorConstant>(Shape{2, 3});
+   //auto float0 = FloatConstant::make(3.0);
+   auto float_t = ngraph::runtime::make_tensor<element::Float32>(Shape{2, 3});
+   auto float0 = make_shared<op::Float32Constant>(Shape{2, 3}, float_t);
    auto float_tensor_type =
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3});
    ASSERT_EQ(*float0->get_value_type(), *float_tensor_type);
@@ -113,7 +113,8 @@ TEST(build_graph, tensor)
    ASSERT_EQ(d->get_arguments().at(0), float0);
    ASSERT_EQ(d->get_arguments().at(1), float0);

-   auto int32_0 = make_shared<op::Int32TensorConstant>(Shape{3, 5});
+   auto int32_t = ngraph::runtime::make_tensor<element::Int32>(Shape{3, 5});
+   auto int32_0 = make_shared<op::Int32Constant>(Shape{3, 5}, int32_t);
    auto int32_tensor_type =
        make_shared<TensorViewType>(element::Int32::element_type(), Shape{3, 5});
    ASSERT_EQ(*int32_0->get_value_type(), *int32_tensor_type);
...
...@@ -997,7 +997,9 @@ TEST(execute, subtract) ...@@ -997,7 +997,9 @@ TEST(execute, subtract)
TEST(execute, scalar_constant) TEST(execute, scalar_constant)
{ {
auto shape = Shape{}; auto shape = Shape{};
auto A = make_shared<op::ScalarConstant<element::Float32>>(-3.0f); auto t = ngraph::runtime::make_tensor<element::Float32>(shape);
(*t) = std::vector<float>{-3.0f};
auto A = make_shared<op::ParameterizedConstant<element::Float32>>(shape, t);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape); auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>(A, rt, op::Parameters{}); auto f = make_shared<Function>(A, rt, op::Parameters{});
...@@ -1016,8 +1018,9 @@ TEST(execute, scalar_constant) ...@@ -1016,8 +1018,9 @@ TEST(execute, scalar_constant)
TEST(execute, tensor_constant) TEST(execute, tensor_constant)
{ {
auto shape = Shape{2, 2, 2}; auto shape = Shape{2, 2, 2};
auto A = make_shared<op::TensorConstant<element::Float32>>(shape); auto t = ngraph::runtime::make_tensor<element::Float32>(shape);
A->get_value()->get_vector() = {1, 2, 3, 4, 5, 6, 7, 8}; (*t) = std::vector<float>{1, 2, 3, 4, 5, 6, 7, 8};
auto A = make_shared<op::ParameterizedConstant<element::Float32>>(shape, t);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape); auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>(A, rt, op::Parameters{}); auto f = make_shared<Function>(A, rt, op::Parameters{});
...@@ -1036,8 +1039,9 @@ TEST(execute, tensor_constant) ...@@ -1036,8 +1039,9 @@ TEST(execute, tensor_constant)
TEST(execute, tensor_constant_with_op) TEST(execute, tensor_constant_with_op)
{ {
auto shape = Shape{2, 2, 2}; auto shape = Shape{2, 2, 2};
auto A = make_shared<op::TensorConstant<element::Float32>>(shape); auto t = ngraph::runtime::make_tensor<element::Float32>(shape);
A->get_value()->get_vector() = {-1, 2, 3, -4, 5, -6, -7, 8}; (*t) = std::vector<float>{-1, 2, 3, -4, 5, -6, -7, 8};
auto A = make_shared<op::ParameterizedConstant<element::Float32>>(shape, t);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape); auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>(make_shared<op::Abs>(A), rt, op::Parameters{}); auto f = make_shared<Function>(make_shared<op::Abs>(A), rt, op::Parameters{});
...@@ -1882,7 +1886,7 @@ TEST(execute, sin) ...@@ -1882,7 +1886,7 @@ TEST(execute, sin)
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape); auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
std::transform( std::transform(
input.begin(), input.end(), input.begin(), [](float v) -> float { return sinf(v); }); input.begin(), input.end(), input.begin(), [](float x) -> float { return sinf(x); });
(*cf)({a}, {result}); (*cf)({a}, {result});
ASSERT_EQ(input, result->get_vector()); ASSERT_EQ(input, result->get_vector());
...@@ -1908,7 +1912,7 @@ TEST(execute, cos) ...@@ -1908,7 +1912,7 @@ TEST(execute, cos)
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape); auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
std::transform( std::transform(
- input.begin(), input.end(), input.begin(), [](float v) -> float { return cosf(v); });
+ input.begin(), input.end(), input.begin(), [](float x) -> float { return cosf(x); });
(*cf)({a}, {result});
ASSERT_EQ(input, result->get_vector());

@@ -1934,7 +1938,7 @@ TEST(execute, tan)
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
std::transform(
- input.begin(), input.end(), input.begin(), [](float v) -> float { return tanf(v); });
+ input.begin(), input.end(), input.begin(), [](float x) -> float { return tanf(x); });
(*cf)({a}, {result});
ASSERT_EQ(input, result->get_vector());

@@ -1959,7 +1963,7 @@ TEST(execute, asin)
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
std::transform(
- input.begin(), input.end(), input.begin(), [](float v) -> float { return asinf(v); });
+ input.begin(), input.end(), input.begin(), [](float x) -> float { return asinf(x); });
(*cf)({a}, {result});
ASSERT_EQ(input, result->get_vector());

@@ -1984,7 +1988,7 @@ TEST(execute, acos)
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
std::transform(
- input.begin(), input.end(), input.begin(), [](float v) -> float { return acosf(v); });
+ input.begin(), input.end(), input.begin(), [](float x) -> float { return acosf(x); });
(*cf)({a}, {result});
ASSERT_EQ(input, result->get_vector());

@@ -2009,7 +2013,7 @@ TEST(execute, atan)
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
std::transform(
- input.begin(), input.end(), input.begin(), [](float v) -> float { return atanf(v); });
+ input.begin(), input.end(), input.begin(), [](float x) -> float { return atanf(x); });
(*cf)({a}, {result});
ASSERT_EQ(input, result->get_vector());

@@ -2034,7 +2038,7 @@ TEST(execute, sinh)
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
std::transform(
- input.begin(), input.end(), input.begin(), [](float v) -> float { return sinhf(v); });
+ input.begin(), input.end(), input.begin(), [](float x) -> float { return sinhf(x); });
(*cf)({a}, {result});
ASSERT_EQ(input, result->get_vector());

@@ -2059,7 +2063,7 @@ TEST(execute, cosh)
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
std::transform(
- input.begin(), input.end(), input.begin(), [](float v) -> float { return coshf(v); });
+ input.begin(), input.end(), input.begin(), [](float x) -> float { return coshf(x); });
(*cf)({a}, {result});
ASSERT_EQ(input, result->get_vector());

@@ -2084,7 +2088,7 @@ TEST(execute, tanh)
auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
std::transform(
- input.begin(), input.end(), input.begin(), [](float v) -> float { return tanhf(v); });
+ input.begin(), input.end(), input.begin(), [](float x) -> float { return tanhf(x); });
(*cf)({a}, {result});
ASSERT_EQ(input, result->get_vector());

@@ -2184,3 +2188,89 @@ TEST(execute, slice_vector)
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}), result->get_vector());
}

TEST(execute, scalar_constant_float32)
{
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto r = make_shared<op::Constant>(element::Float32::element_type(), Shape{}, "4.8");
auto f = make_shared<Function>(r, rt, op::Parameters{});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto result = ngraph::runtime::make_tensor<element::Float32>(Shape{});
(*cf)({}, {result});
ASSERT_EQ(vector<float>{std::strtof("4.8", NULL)}, result->get_vector());
}

TEST(execute, scalar_constant_int64)
{
auto rt = make_shared<TensorViewType>(element::Int64::element_type(), Shape{});
auto r = make_shared<op::Constant>(element::Int64::element_type(), Shape{}, "2112");
auto f = make_shared<Function>(r, rt, op::Parameters{});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto result = ngraph::runtime::make_tensor<element::Int64>(Shape{});
(*cf)({}, {result});
ASSERT_EQ(vector<element::Int64::type>{std::strtol("2112", NULL, 10)}, result->get_vector());
}

TEST(execute, tensor_constant_float32)
{
auto shape = Shape{2, 2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto r = make_shared<op::Constant>(element::Float32::element_type(),
shape,
std::vector<std::string>{"4.8", "4.7", "-5.3", "0"});
auto f = make_shared<Function>(r, rt, op::Parameters{});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto result = ngraph::runtime::make_tensor<element::Float32>(shape);
(*cf)({}, {result});
ASSERT_EQ((vector<float>{std::strtof("4.8", NULL),
std::strtof("4.7", NULL),
std::strtof("-5.3", NULL),
std::strtof("0", NULL)}),
result->get_vector());
}

TEST(execute, tensor_constant_int64)
{
auto shape = Shape{2, 2};
auto rt = make_shared<TensorViewType>(element::Int64::element_type(), shape);
auto r = make_shared<op::Constant>(element::Int64::element_type(),
shape,
std::vector<std::string>{"2112", "1848", "1776", "1964"});
auto f = make_shared<Function>(r, rt, op::Parameters{});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto result = ngraph::runtime::make_tensor<element::Int64>(shape);
(*cf)({}, {result});
ASSERT_EQ((vector<element::Int64::type>{std::strtol("2112", NULL, 10),
std::strtol("1848", NULL, 10),
std::strtol("1776", NULL, 10),
std::strtol("1964", NULL, 10)}),
result->get_vector());
}
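The four constant tests above share one execution sequence: get the `NGVM` manager, compile the function, allocate a backend, build a call frame, and run it with no inputs. Distilled into a single sketch (the helper name `evaluate_constant_graph` is hypothetical, not an ngraph API; the calls inside are exactly those used in the tests, and the same `using namespace` declarations as the test file are assumed):

```cpp
// Sketch of the compile-and-run pattern shared by the constant tests above.
// Assumes: using namespace std; using namespace ngraph; (as in the test file).
template <typename ET>
std::vector<typename ET::type> evaluate_constant_graph(const std::shared_ptr<Function>& f,
                                                       const Shape& shape)
{
    auto manager = runtime::Manager::get("NGVM"); // select the NGVM runtime
    auto external = manager->compile(f);          // compile the function graph
    auto backend = manager->allocate_backend();
    auto cf = backend->make_call_frame(external); // executable call frame
    auto result = ngraph::runtime::make_tensor<ET>(shape);
    (*cf)({}, {result}); // no inputs: the graph is a constant expression
    return result->get_vector();
}
```

With such a helper, `scalar_constant_float32` reduces to comparing `evaluate_constant_graph<element::Float32>(f, Shape{})` against `vector<float>{4.8f}`.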
@@ -50,7 +50,7 @@ TEST(pass, liveness)
pass_manager.register_pass<pass::DumpSorted>(dump_file);
shared_ptr<Function> func = make_test_graph();
- pass_manager.run_passes(func.get());
+ pass_manager.run_passes(func);
auto sorted = func->get_ordered_ops();
// for (const Node* node : sorted)
...

@@ -39,7 +39,7 @@ TEST(pass_manager, add)
auto graph = make_test_graph();
size_t node_count = get_node_count(graph->get_result());
- pass_manager.run_passes(graph.get());
+ pass_manager.run_passes(graph);
auto sorted = graph->get_ordered_ops();
EXPECT_EQ(node_count, sorted.size());
EXPECT_TRUE(validate_list(sorted));
...

@@ -23,7 +23,7 @@ using namespace ngraph;
// This function traverses the list of ops and verifies that each op's dependencies (its inputs)
// is located earlier in the list. That is enough to be valid
- bool validate_list(const list<Node*>& nodes)
+ bool validate_list(const list<shared_ptr<Node>>& nodes)
{
bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)

@@ -39,7 +39,7 @@ bool validate_list(const list<Node*>& nodes)
for (; tmp != nodes.rend(); tmp++)
{
auto dep_tmp = *tmp;
- auto found = find(dependencies.begin(), dependencies.end(), dep_tmp);
+ auto found = find(dependencies.begin(), dependencies.end(), dep_tmp.get());
if (found != dependencies.end())
{
dependencies.erase(found);
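The comment at the top of `validate_list` states the invariant being checked: every node's dependencies (its inputs) must appear earlier in the list. The utility walks the list in reverse; an equivalent forward-order sketch may be easier to read (this assumes `Node::get_arguments()` returns the node's input nodes as `shared_ptr`s, per the conventions in this codebase; `is_topologically_sorted` is a hypothetical name):

```cpp
#include <list>
#include <memory>
#include <set>
// Plus the ngraph headers used by the tests.

// Sketch: forward-order statement of the invariant that validate_list checks
// in reverse. A list is validly ordered iff each node's inputs precede it.
bool is_topologically_sorted(const std::list<std::shared_ptr<ngraph::Node>>& nodes)
{
    std::set<ngraph::Node*> seen;
    for (const auto& node : nodes)
    {
        for (const auto& arg : node->get_arguments()) // the node's inputs
        {
            if (seen.count(arg.get()) == 0)
            {
                return false; // an input appears later in the list (or never)
            }
        }
        seen.insert(node.get());
    }
    return true;
}
```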
@@ -82,6 +82,6 @@ shared_ptr<Function> make_test_graph()
size_t get_node_count(std::shared_ptr<Node> n)
{
size_t node_count = 0;
- traverse_nodes(n, [&](const Node* node) { node_count++; });
+ traverse_nodes(n, [&](shared_ptr<Node> node) { node_count++; });
return node_count;
}

@@ -23,6 +23,6 @@ namespace ngraph
class Function;
}
- bool validate_list(const std::list<ngraph::Node*>& nodes);
+ bool validate_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes);
std::shared_ptr<ngraph::Function> make_test_graph();
size_t get_node_count(std::shared_ptr<ngraph::Node> n);

@@ -126,7 +126,7 @@ TEST(benchmark, topological_sort)
NGRAPH_INFO << "topological sort took " << timer.get_milliseconds() << "ms";
size_t node_count = 0;
- traverse_nodes(result, [&](const Node* node) { node_count++; });
+ traverse_nodes(result, [&](shared_ptr<Node> node) { node_count++; });
NGRAPH_INFO << "node count " << node_count;

@@ -135,6 +135,7 @@ TEST(benchmark, topological_sort)
timer.stop();
NGRAPH_INFO << "delete nodes took " << timer.get_milliseconds() << "ms";
}

TEST(topological_sort, collect_functions)
{
// First create "f(A,B,C) = (A+B)*C".

@@ -174,7 +175,7 @@ TEST(topological_sort, collect_functions)
set<string> expected = {"f", "g", "h"};
auto functions = pass_manager.get_state().get_functions();
vector<string> fnames;
- for (Function* func : functions)
+ for (shared_ptr<Function> func : functions)
{
fnames.push_back(func->get_name());
}
...
@@ -1517,3 +1517,98 @@ TEST(type_prop, slice_deduce_matrix_upper_extra)
FAIL() << "Deduced type check failed for unexpected reason";
}
}

TEST(type_prop, scalar_constant_deduce_float32)
{
auto c = make_shared<op::Constant>(element::Float32::element_type(), Shape{}, "208");
c->propagate_types();
ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{}));
}

TEST(type_prop, scalar_constant_deduce_bool)
{
auto c = make_shared<op::Constant>(element::Bool::element_type(), Shape{}, "1");
c->propagate_types();
ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Bool::element_type(), Shape{}));
}

TEST(type_prop, tensor_constant_deduce_float32)
{
auto c = make_shared<op::Constant>(element::Float32::element_type(),
Shape{2, 2},
std::vector<std::string>{"208", "208", "208", "208"});
c->propagate_types();
ASSERT_EQ(*(c->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{2, 2}));
}

TEST(type_prop, tensor_constant_deduce_bool)
{
auto c = make_shared<op::Constant>(
element::Bool::element_type(), Shape{2, 2}, std::vector<std::string>{"1", "1", "1", "1"});
c->propagate_types();
ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Bool::element_type(), Shape{2, 2}));
}

TEST(type_prop, tensor_constant_bad_parse)
{
auto c = make_shared<op::Constant>(element::Bool::element_type(),
Shape{2, 2},
std::vector<std::string>{"1", "grunk", "1", "1"});
try
{
c->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Bad literal parse not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Could not parse literal"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}

TEST(type_prop, tensor_constant_bad_parse_float_for_int)
{
auto c = make_shared<op::Constant>(element::Int32::element_type(),
Shape{2, 2},
std::vector<std::string>{"1", "2.7", "1", "1"});
try
{
c->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Bad literal parse not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Could not parse literal"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}

TEST(type_prop, tensor_constant_bad_count)
{
auto c = make_shared<op::Constant>(
element::Bool::element_type(), Shape{2, 2}, std::vector<std::string>{"1", "1", "1"});
try
{
c->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect number of literals not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string("Constant does not have the expected number of literals"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
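Taken together, the three failure tests pin down the error contract of `Constant::propagate_types`: an unparseable literal and a wrong literal count both surface as an `ngraph_error` with a fixed message. A minimal caller-side sketch relying on that contract (`make_checked_constant` is a hypothetical helper; the constructor shape and the `element::Type` binding are assumed from the tests above):

```cpp
#include <memory>
#include <string>
#include <vector>
// Plus the ngraph headers used by the tests.

// Sketch: build a Constant and validate its literals eagerly, relying on the
// ngraph_error contract exercised by the tests above. Hypothetical helper.
std::shared_ptr<ngraph::op::Constant>
    make_checked_constant(const ngraph::element::Type& et,
                          const ngraph::Shape& shape,
                          const std::vector<std::string>& literals)
{
    auto c = std::make_shared<ngraph::op::Constant>(et, shape, literals);
    c->propagate_types(); // throws ngraph_error on bad parse or wrong count
    return c;
}
```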