//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/specialize_function.hpp"

using namespace ngraph;

// Simple case: create a function with static parameter shapes and "specialize" them to the same
// shapes.
26
TEST(specialize_function, et_shape_static)
27 28 29 30 31 32 33 34 35
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
    auto p1 = std::make_shared<op::Parameter>(element::i32, Shape{1, 2, 3});

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

36 37 38 39 40 41
    std::vector<void*> param_vals{nullptr, nullptr};

    auto g = specialize_function(f,
                                 {element::f32, element::i32},
                                 {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
                                 param_vals);
42 43 44 45 46 47

    ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(g->get_output_element_type(0), element::f32);
}

// Test specialization of dynamic element types.
48
TEST(specialize_function, et_dynamic_shape_static)
49 50 51 52 53 54 55 56 57
{
    auto p0 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
    auto p1 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

58 59 60 61 62 63
    std::vector<void*> param_vals{nullptr, nullptr};

    auto g = specialize_function(f,
                                 {element::f32, element::i32},
                                 {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
                                 param_vals);
64 65 66 67 68 69

    ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(g->get_output_element_type(0), element::f32);
}

// Test specialization of rank-dynamic shapes.
70
TEST(specialize_function, et_static_shape_rank_dynamic)
71 72 73 74 75 76 77 78 79
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic());

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

80 81 82 83 84 85
    std::vector<void*> param_vals{nullptr, nullptr};

    auto g = specialize_function(f,
                                 {element::f32, element::i32},
                                 {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
                                 param_vals);
86 87 88 89 90 91

    ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(g->get_output_element_type(0), element::f32);
}

// Test specialization of rank-static dynamic shapes.
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
TEST(specialize_function, et_static_shape_rank_static_dynamic)
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
    auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

    std::vector<void*> param_vals{nullptr, nullptr};

    auto g = specialize_function(f,
                                 {element::f32, element::i32},
                                 {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
                                 param_vals);

    ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(g->get_output_element_type(0), element::f32);
}

// Test substitution of concrete values for a shape-dynamic parameter. Supplying data for p1
// is expected to replace it with an op::Constant holding that data; the assertions below walk
// Result -> Add -> Convert -> Constant to verify this.
TEST(specialize_function, et_static_shape_rank_static_dynamic_subst_val)
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
    auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

    // Six i32 values: one per element of the {1, 2, 3} shape p1 is specialized to.
    // The vector owns the data that param_vals[1] points at.
    std::vector<int32_t> p1_subst_vals{5, 0, 3, 8, 5, 8};

    // p0 gets no value (nullptr); p1 gets the concrete data above.
    std::vector<void*> param_vals{nullptr, p1_subst_vals.data()};

    auto g = specialize_function(f,
                                 {element::f32, element::i32},
                                 {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
                                 param_vals);

    ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(g->get_output_element_type(0), element::f32);

    // The Convert that originally consumed p1 should now consume a Constant.
    auto plus_node = std::dynamic_pointer_cast<op::Add>(g->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(plus_node);
    auto convert_node = std::dynamic_pointer_cast<op::Convert>(plus_node->get_argument(1));
    ASSERT_TRUE(convert_node);
    auto const_node = std::dynamic_pointer_cast<op::Constant>(convert_node->get_argument(0));
    ASSERT_TRUE(const_node);

    // The Constant must carry p1's specialized element type, shape, and the supplied data.
    ASSERT_EQ(const_node->get_output_element_type(0), element::i32);
    ASSERT_EQ(const_node->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(const_node->get_vector<int32_t>(), p1_subst_vals);
}

// Test specialization of rank-dynamic shapes to a case where validation will fail.
//
// (The input shapes we provide at specialization time are inconsistent.)
151
TEST(specialize_function, et_static_shape_rank_dynamic_validation_fails)
152 153 154 155 156 157 158 159 160
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic());

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

161 162
    std::vector<void*> param_vals{nullptr, nullptr};

163 164
    ASSERT_THROW(
        {
165 166 167 168
            specialize_function(f,
                                {element::f32, element::i32},
                                {PartialShape{1, 2, 3}, PartialShape{1, 2, 3, 4}},
                                param_vals);
169 170 171 172 173 174 175
        },
        NodeValidationFailure);
}

// Test specialization of dynamic element types to a case where validation will fail.
//
// (The input element types we provide at specialization time are inconsistent.)
176
TEST(specialize_function, et_dynamic_shape_static_validation_fails)
177 178 179 180 181 182 183 184 185
{
    auto p0 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
    auto p1 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

186 187
    std::vector<void*> param_vals{nullptr, nullptr};

188 189
    ASSERT_THROW(
        {
190 191 192 193
            specialize_function(f,
                                {element::u32, element::i32},
                                {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
                                param_vals);
194 195 196 197 198 199 200 201
        },
        NodeValidationFailure);
}

// Test specialization of rank-static dynamic shapes, where the replacement shapes have the wrong
// rank.
//
// (Note that we are testing for a different exception class here because the failure is in
202
// specialize_shape's pre-checks, which use NGRAPH_CHECK, rather than inside validation as we
203
// reconstruct the graph.)
204
TEST(specialize_function, et_static_shape_rank_static_dynamic_rank_mismatch)
205 206 207 208 209 210 211 212 213
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
    auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

214 215
    std::vector<void*> param_vals{nullptr, nullptr};

216 217
    ASSERT_THROW(
        {
218 219 220 221
            specialize_function(f,
                                {element::f32, element::i32},
                                {PartialShape{1, 2, 3}, PartialShape{1, 2, 3, 4}},
                                param_vals);
222
        },
223
        CheckFailure);
224 225 226 227 228 229
}

// Test specialization of rank-static dynamic shapes, where the replacement shapes have wrong
// dimensions.
//
// (Note that we are testing for a different exception class here because the failure is in
230
// specialize_shape's pre-checks, which use NGRAPH_CHECK, rather than inside validation as we
231
// reconstruct the graph.)
232
TEST(specialize_function, et_static_shape_rank_static_dynamic_dim_mismatch)
233 234 235 236 237 238 239 240 241 242
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
    auto p1 =
        std::make_shared<op::Parameter>(element::i32, PartialShape{1, Dimension::dynamic(), 3});

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

243 244
    std::vector<void*> param_vals{nullptr, nullptr};

245 246
    ASSERT_THROW(
        {
247 248 249 250
            specialize_function(f,
                                {element::f32, element::i32},
                                {PartialShape{1, 2, 3}, PartialShape{1, 9, 4}},
                                param_vals);
251
        },
252
        CheckFailure);
253 254 255
}

// Test for failure when we supply the wrong number of replacement element types.
256
TEST(specialize_function, et_count_wrong)
257 258 259 260 261 262 263 264 265
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
    auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

266 267
    std::vector<void*> param_vals{nullptr, nullptr};

268 269
    ASSERT_THROW(
        {
270 271 272 273
            specialize_function(f,
                                {element::f32, element::i32, element::u32},
                                {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
                                param_vals);
274
        },
275
        CheckFailure);
276 277 278
}

// Test for failure when we supply the wrong number of replacement shapes.
279
TEST(specialize_function, shape_count_wrong)
280 281 282 283 284 285 286 287 288
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
    auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

289 290
    std::vector<void*> param_vals{nullptr, nullptr};

291 292
    ASSERT_THROW(
        {
293
            specialize_function(
294 295
                f,
                {element::f32, element::i32},
296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320
                {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}, PartialShape{4, 5, 6}},
                param_vals);
        },
        CheckFailure);
}

// Test for failure when we supply the wrong number of replacement parameter values
// (three value pointers for a function with two parameters).
TEST(specialize_function, value_count_wrong)
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
    auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

    // One entry too many: the function only has two parameters.
    std::vector<void*> param_vals{nullptr, nullptr, nullptr};

    ASSERT_THROW(
        {
            specialize_function(f,
                                {element::f32, element::i32},
                                {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
                                param_vals);
        },
        CheckFailure);
}