//***************************************************************************** // Copyright 2017-2019 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //***************************************************************************** #include <getopt.h> #include <iostream> #include <memory> #include "ngraph/file_util.hpp" #include "ngraph/runtime/plaidml/plaidml_backend.hpp" #include "ngraph/serializer.hpp" static const struct option opts[] = {{"backend", required_argument, nullptr, 'b'}, {"format", required_argument, nullptr, 'f'}, {"help", no_argument, nullptr, 'h'}, {nullptr, 0, nullptr, '\0'}}; int main(int argc, char** argv) { int opt; bool err = false; bool usage = false; std::string model; std::string output; std::string backend_name = "PlaidML"; plaidml_file_format format = PLAIDML_FILE_FORMAT_TILE; while ((opt = getopt_long(argc, argv, "f:b:h", opts, nullptr)) != -1) { switch (opt) { case 'b': backend_name = optarg; break; case 'h': usage = true; break; case 'f': if (!strcmp(optarg, "tile")) { format = PLAIDML_FILE_FORMAT_TILE; } else if (!strcmp(optarg, "human")) { format = PLAIDML_FILE_FORMAT_STRIPE_HUMAN; } else if (!strcmp(optarg, "prototxt")) { format = PLAIDML_FILE_FORMAT_STRIPE_PROTOTXT; } else if (!strcmp(optarg, "binary")) { format = PLAIDML_FILE_FORMAT_STRIPE_BINARY; } else { err = true; } break; case '?': default: err = true; break; } } if (optind + 2 != argc) { err = true; } else { model = argv[optind]; output = argv[optind + 1]; if (model.empty()) { err = true; } else if (!ngraph::file_util::exists(model)) { std::cerr << "File " << model << " not found\n"; err = true; } if (output.empty()) { err = true; } else if (ngraph::file_util::exists(output)) { std::cerr << "File " << output << " already exists; not overwriting\n"; err = true; } } if (backend_name.substr(0, backend_name.find(':')) != "PlaidML") { std::cerr << "Unsupported backend: " << backend_name << "\n"; err = true; } if (err || usage) { std::cerr << R"###( DESCRIPTION Convert an ngraph JSON model to one of PlaidML's file formats. SYNOPSIS ngraph-to-plaidml [--backend|-b <backend>] MODEL OUTPUT OPTIONS -b|--backend Backend to use (default: PlaidML) -f|--format Format to use (tile, human, prototxt, binary, or json; default: tile) )###"; } if (err) { return EXIT_FAILURE; } if (usage) { return EXIT_SUCCESS; } std::cerr << "Reading nGraph model from " << model << "\n"; std::shared_ptr<ngraph::Function> f = ngraph::deserialize(model); std::shared_ptr<ngraph::runtime::Backend> base_backend = ngraph::runtime::Backend::create(backend_name); std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Backend> backend = std::dynamic_pointer_cast<ngraph::runtime::plaidml::PlaidML_Backend>(base_backend); if (!backend) { std::cerr << "Failed to load PlaidML backend\n"; return EXIT_FAILURE; } auto exec = backend->compile(f); static_cast<ngraph::runtime::plaidml::PlaidML_Executable*>(exec.get())->save(output, format); std::cerr << "Wrote output to " << output << "\n"; return EXIT_SUCCESS; }