Tensorflow tf.app.run()与命令行参数解析   2018-10-23


tf.app.run() 与 命令行参数解析 tf.flags

首先给出一段常见的代码:

if __name__ == '__main__':
tf.app.run()

找到 Tensorflow 中关于上述 函数run() 的源码:

def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS

# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None

# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access

main = main or _sys.modules['__main__'].main

# Call the main function, passing through any arguments
# to the final program.
_sys.exit(main(_sys.argv[:1] + flags_passthrough))


_allowed_symbols = [
'run',
# Allowed submodule.
'flags',
]

remove_undocumented(__name__, _allowed_symbols)

可以看到源码中的过程是首先加载 flags 的参数项,然后执行 main 函数。参数是使用tf.app.flags.FLAGS 定义的。

tf.app.flags.FLAGS

关于 tf.app.flags.FLAGS 的使用:

# fila_name: temp.py
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('string', 'train', 'This is a string')
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'This is the rate in training')
tf.app.flags.DEFINE_boolean('flag', True, 'This is a flag')

print('string: ', FLAGS.string)
print('learning_rate: ', FLAGS.learning_rate)
print('flag: ', FLAGS.flag)

输出:

('string: ', 'train')
('learning_rate: ', 0.001)
('flag: ', True)

Reference


分享到:


  如果您觉得这篇文章对您的学习很有帮助, 请您也分享它, 让它能再次帮助到更多的需要学习的人. 您的支持将鼓励我继续创作 !
本文基于署名4.0国际许可协议发布,转载请保留本文署名和文章链接。 如您有任何授权方面的协商,请邮件联系我。

Contents

  1. tf.app.flags.FLAGS
  2. Reference