// ---------------------------------------------------------------------------- // Copyright 2017 Nervana Systems Inc. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // ---------------------------------------------------------------------------- #include <sstream> #include <string> #include <vector> #include "gtest/gtest.h" #include "transformers/op_graph.hpp" using namespace ngraph; TEST(op_graph, constant) { float expected_value = 42; op_ptr x = constant(expected_value); ASSERT_NE(nullptr, x); EXPECT_EQ(true, x->is_constant()); EXPECT_EQ(false, x->is_input()); EXPECT_EQ(true, x->is_persistent()); EXPECT_EQ(false, x->is_trainable()); EXPECT_EQ(false, x->is_placeholder()); auto ato = std::dynamic_pointer_cast<AssignableTensorOp>(x); ASSERT_NE(nullptr, ato); // TODO: fix this auto ti = ato->m_value; ASSERT_NE(nullptr, ti); std::string actual_value = ti->value_string(); std::stringstream ss; ss << expected_value; std::string expected_string = ss.str(); EXPECT_STREQ(actual_value.c_str(), expected_string.c_str()); } // @pytest.fixture() // def N(): Axis N() { // return ng.make_axis(length=1) return make_axis(1); } // def test_deriv_missing_connection(N): // """ // Taking the derivative of an expression with respect to a variable not // used to compute the expression should raise an exception. // """ TEST(op_graph, deriv_missing_connection) { // x = ng.variable([N]) // auto x = variable({N()}); // y = ng.variable([N]) // z = ng.variable([N]) // with pytest.raises(ValueError): // ng.deriv(x + y, z) } // def test_one(): // # Test that the cacheing on constant one used in DerivOp works. // op = ng.variable([]) // one_0 = op.one // one_1 = op.one // assert one_0 is one_1 // def test_pad_invalid_paddings_length(N): // """ // pad should raise an exception if the paddings length is not the same as the // input dimensionality. // """ // x = ng.variable([N]) // with pytest.raises(ValueError): // ng.pad(x, [1, 0]) // def test_pad_0(N): // """ // pad with length 0 should be a nop // """ // x = ng.variable([N]) // assert ng.pad(x, [0]).axes == x.axes // def test_pad_mixed(): // """ // mix 0 padding with non-0 padding // """ // input_axes = ng.make_axes([ // ng.make_axis(1), // ng.make_axis(1) // ]) // x = ng.variable(input_axes) // pad = ng.pad(x, [0, 1]) // assert pad.axes[0] == x.axes[0] // assert pad.axes[1] != x.axes[1] // def test_slice_nop(): // """ // slicing an axis shouldn't change the name // """ // input_axes = ng.make_axes([ // ng.make_axis(1), // ng.make_axis(1) // ]) // x = ng.variable(input_axes) // s = ng.tensor_slice(x, [ // slice(None, None, None), // slice(None, None, 1), // ]) // assert s.axes[0] == x.axes[0] // assert s.axes[1] == x.axes[1] // def test_tensor_slice(): // """ // slicing a tensor should work like numpy // """ // input_axes = ng.make_axes([ // ng.make_axis(10), // ng.make_axis(20), // ng.make_axis(5) // ]) // x = ng.placeholder(axes=input_axes) // assert x[:5].axes.full_lengths == (5, 20, 5) // assert x[:, 2:7].axes.full_lengths == (10, 5, 5) // assert x[:5, :, :-1].axes.full_lengths == (5, 20, 4)