#!/usr/bin/env python3 import argparse import os import sys import tensorflow as tf os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' unlikely_output_types = ['Const', 'Assign', 'NoOp', 'Placeholder', 'Assert'] def children(op_name: str, graph: tf.Graph): op = graph.get_operation_by_name(op_name) return set(op for out in op.outputs for op in out.consumers()) def summarize_graph(graph_def): placeholders = dict() outputs = list() graph = tf.Graph() with graph.as_default(): tf.import_graph_def(graph_def, name='') for node in graph.as_graph_def().node: if node.op == 'Placeholder': node_dict = dict() node_dict['type'] = tf.DType(node.attr['dtype'].type).name node_dict['shape'] = str(tf.TensorShape(node.attr['shape'].shape)).replace(' ', '').replace('?', '-1') placeholders[node.name] = node_dict if len(children(node.name, graph)) == 0: if node.op not in unlikely_output_types and node.name.split('/')[-1] not in unlikely_output_types: outputs.append(node.name) result = dict() result['inputs'] = placeholders result['outputs'] = outputs return result if __name__ == "__main__": # pragma: no cover sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) from mo.front.tf.loader import load_tf_graph_def parser = argparse.ArgumentParser() parser.add_argument("--input_model", type=str, help="Path to tensorflow model", default="") parser.add_argument('--input_model_is_text', dest='text', help='TensorFlow*: treat the input model file in a text protobuf format instead of ' + 'binary, which is default.', action='store_true', default=False) parser.add_argument('--input_meta', action='store_true', help='TensorFlow*: treat the input model file in a meta graph def format', default=False) parser.add_argument("--input_checkpoint", type=str, help='TensorFlow variables file to load.', default="") parser.add_argument('--saved_model_dir', type=str, default="", help="TensorFlow saved_model_dir") parser.add_argument('--saved_model_tags', type=str, default="", help="Group of tag(s) of the MetaGraphDef to load, in string \ format, separated by ','. For tag-set contains multiple tags, all tags must be passed in.") argv = parser.parse_args() if not argv.input_model and not argv.saved_model_dir: print("[ ERROR ] Please, provide --input_model and --input_model_is_text if needed or --input_dir for saved " \ "model directory") sys.exit(1) if argv.input_model and argv.saved_model_dir: print("[ ERROR ] Both keys were provided --input_model and --input_dir. Please, provide only one of them") sys.exit(1) graph_def = load_tf_graph_def(graph_file_name=argv.input_model, is_binary=not argv.text, checkpoint=argv.input_checkpoint, model_dir=argv.saved_model_dir, saved_model_tags=argv.saved_model_tags) summary = summarize_graph(graph_def) print("{} input(s) detected:".format(len(summary['inputs']))) for input in summary['inputs']: print("Name: {}, type: {}, shape: {}".format(input, summary['inputs'][input]['type'], summary['inputs'][input]['shape'])) print("{} output(s) detected:".format(len(summary['outputs']))) print(*summary['outputs'], sep="\n")