//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

class CompatOp : public ngraph::op::Op
{
public:
    static constexpr NodeTypeInfo type_info{"CompatOp", 0};
    const NodeTypeInfo& get_type_info() const override { return type_info; }
    CompatOp() = default;

    CompatOp(const Output<Node>& value)
        : Op({value})
    {
    }

    // Test for API compatibility
    bool visit_attributes(AttributeVisitor& visitor) override { return true; }
    std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override
    {
        return make_shared<CompatOp>(new_args.at(0));
    }
    void validate_and_infer_types() override
    {
        auto arg = input_value(0);
        set_output_type(0, arg.get_element_type(), arg.get_shape());
    }

protected:
    // Deprecated method
    virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override
    {
        adjoints.add_delta(input_value(0), input_value(0) * deltas.at(0));
    }
};

constexpr NodeTypeInfo CompatOp::type_info;

// Builds a graph containing CompatOp and checks that the legacy hooks it
// overrides are honoured:
// - shape inference forwards the input's element type and shape;
// - backprop through the op goes via its generate_adjoints, producing a
//   Multiply of `param` and the seed delta `c`.
TEST(compat, node)
{
    auto param = make_shared<op::Parameter>(element::f32, Shape{10});
    // Passed to Adjoints below as the delta for `result`.
    auto c = make_shared<op::Parameter>(element::f32, Shape{10});
    auto x = make_shared<CompatOp>(param);
    auto result = make_shared<op::Result>(x);
    // NOTE(review): `f` is not read afterwards. Presumably constructing the
    // Function validates the graph (running CompatOp::validate_and_infer_types);
    // confirm before removing it.
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{param});
    EXPECT_EQ(x->get_output_element_type(0), element::f32);
    EXPECT_EQ(x->get_output_shape(0), (Shape{10}));

    autodiff::Adjoints adjoints({result}, {c});
    auto bprop = adjoints.backprop_output(param);
    ASSERT_EQ(bprop.get_index(), 0u);
    ASSERT_TRUE(is_type<op::v0::Multiply>(bprop.get_node_shared_ptr()));

    // Compare the Multiply operands as a set so the check does not depend on their order.
    set<Output<Node>> params;
    params.insert(bprop.get_node_shared_ptr()->input_value(0));
    params.insert(bprop.get_node_shared_ptr()->input_value(1));
    EXPECT_EQ(params.count(param), 1u);
    EXPECT_EQ(params.count(c), 1u);
}