Unverified Commit 69a2d4aa authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Add copy_with_new_args primitive for subgraph cloning (#225)

parent 0513ad96
# Copyright 2017 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
NGRAPH_DIST_DIR = ${HOME}/ngraph_dist
CXXFLAGS += -std=c++11
CPPFLAGS += -I $(NGRAPH_DIST_DIR)
LDFLAGS = -L $(NGRAPH_DIST_DIR)
OBJ = main.o
%.o: %.cpp $(DEPS)
$(CXX) -c -o $@ $< $(CXXFLAGS) $(CPPFLAGS)
ngraph-test: $(OBJ)
$(CXX) -o $@ $(OBJ) $(LDFLAGS) -lngraph
.PHONY: clean
clean:
rm -f $(OBJ) ngraph-test
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <stdio.h>
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/dot.hpp"
using namespace std;
using namespace ngraph;
int main(int argc, char** argv)
{
printf( "Building graph\n" );
// Function with 4 parameters
auto arg0 = op::parameter(element::Float::type, {7, 3});
auto arg1 = op::parameter(element::Float::type, {3});
auto arg2 = op::parameter(element::Float::type, {32, 7});
auto arg3 = op::parameter(element::Float::type, {32, 7});
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto dot = op::dot(arg2, arg0);
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
auto result = cluster_0->result();
printf( "Finished\n" );
}
\ No newline at end of file
......@@ -117,8 +117,12 @@ namespace ngraph
std::shared_ptr<Node> backprop_node(const std::shared_ptr<Node>& x,
const std::shared_ptr<Node>& c);
/// Returns the shape if this node has tensor type, othetwise error.
/// Returns the shape if this node has tensor type, otherwise an ngraph-error is thrown.
const Shape& get_shape() const { return m_value_type->get_shape(); }
const element::Type& get_element_type() const { return m_value_type->get_element_type(); }
virtual std::shared_ptr<Node>
copy_with_new_args(const std::vector<std::shared_ptr<Node>>& new_args) const = 0;
protected:
Nodes m_arguments;
std::shared_ptr<const ValueType> m_value_type;
......
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -50,6 +52,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Abs>(new_args.at(0));
}
virtual std::string description() const override { return "Abs"; }
};
}
......
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -50,6 +52,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Acos>(new_args.at(0));
}
virtual std::string description() const override { return "Acos"; }
};
}
......
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -51,6 +53,15 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Add>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Add"; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -50,6 +52,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Asin>(new_args.at(0));
}
virtual std::string description() const override { return "Asin"; }
};
}
......
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -50,6 +52,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Atan>(new_args.at(0));
}
virtual std::string description() const override { return "Atan"; }
};
}
......
......@@ -74,12 +74,20 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
}
virtual std::string description() const override { return "Broadcast"; }
virtual void propagate_types() override;
/// \return An set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
protected:
const Shape& get_broadcast_shape() const { return m_shape; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Ceiling>(new_args.at(0));
}
virtual std::string description() const override { return "Ceiling"; }
};
}
......
......@@ -14,6 +14,8 @@
#pragma once
#include <memory>
#include "ngraph/ops/op.hpp"
namespace ngraph
......@@ -74,6 +76,12 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
return std::make_shared<Concat>(new_args, m_concatenation_axis);
}
virtual std::string description() const override { return "Concatenate"; }
virtual void propagate_types() override;
......
......@@ -82,12 +82,20 @@ namespace ngraph
/// \param value The value of the tensor constant.
ParameterizedConstant(
const Shape& shape,
typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>>& value)
const typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>>& value)
: ConstantBase(std::make_shared<TensorViewType>(T::element_type(), shape))
, m_value(value)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 0)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<ParameterizedConstant<T>>(get_shape(), m_value);
}
virtual std::string description() const override { return "ParameterizedConstant"; }
virtual std::string get_node_id() const override
{
......@@ -103,7 +111,7 @@ namespace ngraph
}
protected:
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value;
const std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value;
};
/// \brief A 32-bit floating-point tensor constant.
......@@ -171,6 +179,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 0)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Constant>(get_element_type(), get_shape(), m_value_strings);
}
virtual std::string description() const override { return "Constant"; }
virtual std::string get_node_id() const override
{
......
......@@ -61,8 +61,17 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Convert>(new_args.at(0), m_element_type);
}
virtual const element::Type&
propagate_element_types(const element::Type& arg_element_type) const override;
const element::Type& get_convert_element_type() const { return m_element_type; }
virtual std::string description() const override { return "Convert"; }
protected:
const ngraph::element::Type& m_element_type;
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Cos>(new_args.at(0));
}
virtual std::string description() const override { return "Cos"; }
};
}
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Cosh>(new_args.at(0));
}
virtual std::string description() const override { return "Cosh"; }
};
}
......
......@@ -52,6 +52,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Divide>(new_args.at(0), new_args.at(1));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -114,6 +114,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Dot>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Dot"; }
virtual void propagate_types() override;
......
......@@ -51,6 +51,15 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Equal>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Equal"; }
};
}
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Exp>(new_args.at(0));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Floor>(new_args.at(0));
}
virtual std::string description() const override { return "Floor"; }
};
}
......
......@@ -60,6 +60,12 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
return std::make_shared<FunctionCall>(m_function, new_args);
}
virtual std::string description() const override { return "FunctionCall"; }
virtual void propagate_types() override;
......
......@@ -60,6 +60,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<GetTupleElement>(new_args.at(0), m_n);
}
virtual void propagate_types() override;
virtual std::string description() const override { return "GetTupleElement"; }
/// \return The index of the tuple element to get.
......
......@@ -51,6 +51,15 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Greater>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Greater"; }
};
}
......
......@@ -51,6 +51,15 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<GreaterEq>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "GreaterEq"; }
};
}
......
......@@ -51,6 +51,15 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Less>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Less"; }
};
}
......
......@@ -51,6 +51,15 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<LessEq>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "LessEq"; }
};
}
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Log>(new_args.at(0));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -51,6 +51,15 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Maximum>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Maximum"; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -51,6 +51,15 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Minimum>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Minimum"; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -52,6 +52,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Multiply>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Multiply"; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Negative>(new_args.at(0));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -51,6 +51,15 @@ namespace ngraph
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<NotEqual>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "NotEqual"; }
};
}
......
......@@ -62,6 +62,14 @@ namespace ngraph
/// \param shape The shape of the parameter.
Parameter(const ngraph::element::Type& element_type, const Shape& shape);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 0)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Parameter>(get_value_type());
}
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
};
......
......@@ -51,6 +51,15 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Power>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Power"; }
};
}
......
......@@ -106,6 +106,15 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Reduce>(
new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes);
}
virtual std::string description() const override { return "Reduce"; }
virtual void propagate_types() override;
......
......@@ -53,6 +53,15 @@ namespace ngraph
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Remainder>(new_args.at(0), new_args.at(1));
}
virtual std::string description() const override { return "Remainder"; }
};
}
......
......@@ -78,6 +78,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Reshape>(new_args.at(0), m_input_order, m_output_shape);
}
virtual std::string description() const override { return "Reshape"; }
virtual void propagate_types() override;
......
......@@ -55,6 +55,15 @@ namespace ngraph
: Builtin(Nodes{arg0, arg1, arg2})
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 3)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2));
}
virtual std::string description() const override { return "Select"; }
virtual void propagate_types() override;
};
......
......@@ -52,6 +52,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sign>(new_args.at(0));
}
virtual std::string description() const override { return "Sign"; }
};
}
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sin>(new_args.at(0));
}
virtual std::string description() const override { return "Sin"; }
};
}
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sinh>(new_args.at(0));
}
virtual std::string description() const override { return "Sinh"; }
};
}
......
......@@ -88,6 +88,15 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Slice>(
new_args.at(0), m_lower_bounds, m_upper_bounds, m_step);
}
virtual std::string description() const override { return "Slice"; }
virtual void propagate_types() override;
......
......@@ -52,6 +52,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Subtract>(new_args.at(0), new_args.at(1));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -93,6 +93,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Sum>(new_args.at(0), m_reduction_axes);
}
virtual std::string description() const override { return "Sum"; }
virtual void propagate_types() override;
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Tan>(new_args.at(0));
}
virtual std::string description() const override { return "Tan"; }
};
}
......
......@@ -50,6 +50,14 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Tanh>(new_args.at(0));
}
virtual std::string description() const override { return "Tanh"; }
};
}
......
......@@ -50,6 +50,12 @@ namespace ngraph
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
return std::make_shared<Tuple>(new_args);
}
virtual std::string description() const override { return "Tuple"; }
virtual void propagate_types() override;
};
......
......@@ -177,6 +177,8 @@ namespace ngraph
const_iterator end() const { return m_elements.end(); }
vtype get_vector() { return m_elements; }
const vtype get_vector() const { return m_elements; }
operator const vtype() const { return m_elements; }
operator vtype() { return m_elements; }
bool operator==(const NDArrayBase<T>& other) const
{
return m_shape == other.m_shape && m_elements == other.m_elements;
......
......@@ -70,6 +70,11 @@ const Shape& TupleType::get_shape() const
throw ngraph_error("get_shape() called on Tuple");
}
const element::Type& TupleType::get_element_type() const
{
throw ngraph_error("get_element_type() called on Tuple");
}
std::ostream& ngraph::operator<<(std::ostream& out, const ValueType& obj)
{
out << "ValueType()";
......
......@@ -43,6 +43,7 @@ namespace ngraph
virtual void collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const = 0;
virtual const Shape& get_shape() const = 0;
virtual const element::Type& get_element_type() const = 0;
friend std::ostream& operator<<(std::ostream&, const ValueType&);
};
......@@ -59,7 +60,7 @@ namespace ngraph
{
}
const element::Type& get_element_type() const { return m_element_type; }
virtual const element::Type& get_element_type() const override { return m_element_type; }
virtual const Shape& get_shape() const override { return m_shape; }
virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(
......@@ -93,6 +94,8 @@ namespace ngraph
return m_element_types;
}
virtual const element::Type& get_element_type() const override;
virtual bool operator==(const ValueType& that) const override;
virtual void collect_tensor_views(
std::vector<std::shared_ptr<const TensorViewType>>& views) const override;
......
......@@ -23,6 +23,7 @@ include_directories(
set (SRC
autodiff.cpp
copy.cpp
build_graph.cpp
eigen.cpp
input_output_assign.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <string>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
template <typename OP>
bool check_unary()
{
Shape shape{1};
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
auto node = make_shared<OP>(arg0);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<OP>(new_node);
return (nullptr != new_node) && (new_args == new_node->get_arguments());
}
template <typename OP>
bool check_binary()
{
Shape shape{1};
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
auto node = make_shared<OP>(arg0, arg1);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<OP>(new_node);
return (nullptr != new_node) && (new_args == new_node->get_arguments());
}
TEST(copy, abs)
{
ASSERT_TRUE(check_unary<op::Abs>());
}
TEST(copy, acos)
{
ASSERT_TRUE(check_unary<op::Acos>());
}
TEST(copy, add)
{
ASSERT_TRUE(check_binary<op::Add>());
}
TEST(copy, asin)
{
ASSERT_TRUE(check_unary<op::Asin>());
}
TEST(copy, atan)
{
ASSERT_TRUE(check_unary<op::Atan>());
}
TEST(copy, broadcast)
{
Shape shape1{1};
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape1);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape1)};
Shape shape{4, 1, 3};
AxisSet axes{0, 2};
auto node = make_shared<op::Broadcast>(arg0, shape, axes);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Broadcast>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(shape == node_cast->get_broadcast_shape());
ASSERT_TRUE(axes == node_cast->get_broadcast_axes());
}
TEST(copy, ceiling)
{
ASSERT_TRUE(check_unary<op::Ceiling>());
}
TEST(copy, concat)
{
Shape shape{1};
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
size_t axis = 1;
auto node = make_shared<op::Concat>(Nodes{arg0, arg1}, axis);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Concat>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(node_cast->get_concatenation_axis() == axis);
}
TEST(copy, parameterized_constant)
{
auto manager = runtime::Manager::get("NGVM");
auto backend = manager->allocate_backend();
// Create some tensors for input/output
auto c = backend->make_parameterized_tensor_view<element::Float32>(
runtime::NDArray<float, 2>({{1, 2}, {3, 4}}));
Shape shape{2, 2};
auto node = make_shared<op::ParameterizedConstant<element::Float32>>(shape, c);
auto new_node = node->copy_with_new_args(Nodes{});
auto node_cast = dynamic_pointer_cast<op::ParameterizedConstant<element::Float32>>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(Nodes{} == new_node->get_arguments());
ASSERT_TRUE(node_cast->get_value() == c);
ASSERT_TRUE(node_cast->get_shape() == shape);
}
TEST(copy, constant)
{
Shape shape{};
vector<string> c{"2.4"};
auto& et = element::Float32::element_type();
auto node = make_shared<op::Constant>(et, shape, c);
auto new_node = node->copy_with_new_args(Nodes{});
auto node_cast = dynamic_pointer_cast<op::Constant>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(Nodes{} == new_node->get_arguments());
ASSERT_TRUE(node_cast->get_value_strings() == c);
ASSERT_TRUE(node_cast->get_shape() == shape);
ASSERT_TRUE(node_cast->get_element_type() == et);
}
TEST(copy, convert)
{
Shape shape;
auto& et = element::Float64::element_type();
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
auto node = make_shared<op::Convert>(arg0, et);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Convert>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(et == node_cast->get_convert_element_type());
}
TEST(copy, cos)
{
ASSERT_TRUE(check_unary<op::Cos>());
}
TEST(copy, cosh)
{
ASSERT_TRUE(check_unary<op::Cosh>());
}
TEST(copy, divide)
{
ASSERT_TRUE(check_binary<op::Divide>());
}
TEST(copy, dot)
{
ASSERT_TRUE(check_binary<op::Dot>());
}
TEST(copy, equal)
{
ASSERT_TRUE(check_binary<op::Equal>());
}
TEST(copy, exp)
{
ASSERT_TRUE(check_unary<op::Exp>());
}
TEST(copy, floor)
{
ASSERT_TRUE(check_unary<op::Floor>());
}
TEST(copy, FunctionCall)
{
Shape shape{1};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>((A + B) * C, rt, op::Parameters{A, B, C});
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto node = make_shared<op::FunctionCall>(f, Nodes{arg0, arg1});
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::FunctionCall>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(node_cast->get_function() == f);
}
TEST(copy, GetTupleElement)
{
Shape shape{1};
size_t n = 0;
auto tuple_type = make_shared<TupleType>(vector<shared_ptr<const ValueType>>{
make_shared<TensorViewType>(element::Float32::element_type(), shape)});
auto arg0 = make_shared<op::Parameter>(tuple_type);
std::vector<std::shared_ptr<Node>> new_args{make_shared<op::Parameter>(tuple_type)};
auto node = make_shared<op::GetTupleElement>(arg0, n);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::GetTupleElement>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(node_cast->get_n() == n);
}
TEST(copy, greater_eq)
{
ASSERT_TRUE(check_binary<op::GreaterEq>());
}
TEST(copy, greater)
{
ASSERT_TRUE(check_binary<op::Greater>());
}
TEST(copy, less_eq)
{
ASSERT_TRUE(check_binary<op::LessEq>());
}
TEST(copy, less)
{
ASSERT_TRUE(check_binary<op::Less>());
}
TEST(copy, log)
{
ASSERT_TRUE(check_unary<op::Log>());
}
TEST(copy, maximum)
{
ASSERT_TRUE(check_binary<op::Maximum>());
}
TEST(copy, minimum)
{
ASSERT_TRUE(check_binary<op::Minimum>());
}
TEST(copy, multiply)
{
ASSERT_TRUE(check_binary<op::Multiply>());
}
TEST(copy, negative)
{
ASSERT_TRUE(check_unary<op::Negative>());
}
TEST(copy, not_equal)
{
ASSERT_TRUE(check_binary<op::NotEqual>());
}
TEST(copy, parameter)
{
Shape shape{1};
auto node = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto new_node = node->copy_with_new_args({});
auto node_cast = dynamic_pointer_cast<op::Parameter>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_node->get_arguments().size() == 0);
ASSERT_TRUE(node->get_value_type() == new_node->get_value_type());
}
TEST(copy, power)
{
ASSERT_TRUE(check_binary<op::Power>());
}
TEST(copy, reduce)
{
Shape scalar_shape{};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), scalar_shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), scalar_shape);
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), scalar_shape);
auto f = make_shared<Function>(A + B, rt, op::Parameters{A, B});
Shape shape{4, 3};
AxisSet axes{1};
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto arg_init = make_shared<op::Parameter>(element::Float32::element_type(), scalar_shape);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), scalar_shape)};
auto node = make_shared<op::Reduce>(arg0, arg_init, f, axes);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Reduce>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(f == node_cast->get_reduction_function());
ASSERT_TRUE(axes == node_cast->get_reduction_axes());
}
TEST(copy, remainder)
{
ASSERT_TRUE(check_binary<op::Remainder>());
}
TEST(copy, reshape)
{
Shape shape_in{2, 3, 4};
AxisVector axes{0, 1, 2};
Shape shape_out{6, 4};
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape_in);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape_in)};
auto node = make_shared<op::Reshape>(arg0, axes, shape_out);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Reshape>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(axes == node_cast->get_input_order());
ASSERT_TRUE(shape_out == node_cast->get_output_shape());
}
TEST(copy, select)
{
Shape shape{1};
auto arg0 = make_shared<op::Parameter>(element::Bool::element_type(), shape);
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Bool::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
auto node = make_shared<op::Select>(arg0, arg1, arg2);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Select>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
}
TEST(copy, sign)
{
ASSERT_TRUE(check_unary<op::Sign>());
}
TEST(copy, sin)
{
ASSERT_TRUE(check_unary<op::Sin>());
}
TEST(copy, sinh)
{
ASSERT_TRUE(check_unary<op::Sinh>());
}
TEST(copy, slice)
{
Shape shape_in{2, 3, 4};
Coordinate lower{0, 0, 0};
Coordinate upper{2, 3, 4};
Coordinate step{1, 1, 1};
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape_in);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape_in)};
auto node = make_shared<op::Slice>(arg0, lower, upper, step);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Slice>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(lower == node_cast->get_lower_bounds());
ASSERT_TRUE(upper == node_cast->get_upper_bounds());
ASSERT_TRUE(step == node_cast->get_step());
}
TEST(copy, subtract)
{
ASSERT_TRUE(check_binary<op::Subtract>());
}
TEST(copy, sum)
{
Shape shape{4, 3};
AxisSet axes{1};
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
auto node = make_shared<op::Sum>(arg0, axes);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Sum>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(axes == node_cast->get_reduction_axes());
}
TEST(copy, tan)
{
ASSERT_TRUE(check_unary<op::Tan>());
}
TEST(copy, tanh)
{
ASSERT_TRUE(check_unary<op::Tanh>());
}
TEST(copy, tuple)
{
Shape shape{1};
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
auto node = make_shared<op::Tuple>(Nodes{arg0, arg1});
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Tuple>(new_node);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment