Commit 877ac969 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

remove parameter check from Function::get_ops() (#834)

* remove parameter check from Function::get_ops()

* create validate pass to hold parameter validation
parent e5c3769d
......@@ -114,6 +114,7 @@ set (SRC
pass/pass.cpp
pass/reshape_elimination.cpp
pass/result_copy_elimination.cpp
pass/validate_graph.cpp
pass/visualize_tree.cpp
pass/core_fusion.cpp
pattern/matcher.cpp
......
......@@ -176,21 +176,7 @@ shared_ptr<Node> Function::get_result() const
std::list<shared_ptr<Node>> Function::get_ops() const
{
std::list<std::shared_ptr<Node>> ops;
traverse_nodes(this, [&](shared_ptr<Node> node) {
ops.push_back(node);
std::shared_ptr<op::Parameter> p = std::dynamic_pointer_cast<op::Parameter>(node);
if (nullptr != p)
{
auto it = std::find_if(m_parameters.begin(),
m_parameters.end(),
[p](std::shared_ptr<op::Parameter> q) { return (p == q); });
if (it == m_parameters.end())
{
throw ngraph_error("Function references undeclared parameter");
}
}
});
traverse_nodes(this, [&](shared_ptr<Node> node) { ops.push_back(node); });
return ops;
}
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/pass/validate_graph.hpp"
using namespace std;
using namespace ngraph;
bool pass::ValidateGraph::run_on_module(vector<shared_ptr<Function>>& functions)
{
for (shared_ptr<Function> f : functions)
{
validate_parameters(*f);
}
return false;
}
void pass::ValidateGraph::validate_parameters(const Function& function)
{
auto parameters = function.get_parameters();
for (auto node : function.get_ops())
{
shared_ptr<op::Parameter> p = dynamic_pointer_cast<op::Parameter>(node);
if (nullptr != p)
{
auto it = find_if(parameters.begin(),
parameters.end(),
[p](shared_ptr<op::Parameter> q) { return (p == q); });
if (it == parameters.end())
{
throw ngraph_error("Function references undeclared parameter");
}
}
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class ValidateGraph;
}
}
class ngraph::pass::ValidateGraph : public ModulePass
{
public:
bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
private:
void validate_parameters(const Function&);
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment