在pycharm和tensorflow環境下運行nmt
阿新 • • 發佈:2018-07-28
light enc -o arm mas 環境 編譯 -- charm
目的是在pycharm中調試nmt代碼,主要做了如下工作:
配置pycharm編譯環境
在File->Settings->Project->Project Interpreter 設置TensorFlow所在的python環境
新建程序主代碼
在nmt文件夾之外新建了nmt_main.py代碼,copy nmt.py的程序入口代碼到其中。如下:
from nmt.nmt import * FLAGS = None root_dir = ‘D:/tensorflow/nmt-master‘ def main(unused_argv): default_hparams = create_hparams(FLAGS) train_fn = train.train inference_fn = inference.inference run_main(FLAGS, default_hparams, train_fn, inference_fn) if __name__ == "__main__": sys.argv = [‘nmt_main.py‘, ‘--src=vi‘, ‘--tgt=en‘, ‘--vocab_prefix=‘ + root_dir + ‘/nmt_data/vocab‘, ‘--train_prefix=‘ + root_dir + ‘/nmt_data/train‘, ‘--dev_prefix=‘ + root_dir + ‘/nmt_data/tst2012‘, ‘--test_prefix=‘ + root_dir + ‘/nmt_data/tst2013‘, ‘--out_dir=‘ + root_dir + ‘/nmt_data/nmt_model‘, ‘--num_train_steps=12000‘, ‘--steps_per_stats=100‘, ‘--num_layers=2‘, ‘--num_units=128‘, ‘--dropout=0.2‘, ‘--metrics=bleu‘] nmt_parser = argparse.ArgumentParser() add_arguments(nmt_parser) # print(nmt_parser) FLAGS, unparsed = nmt_parser.parse_known_args() # print(unparsed) tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
在pycharm和tensorflow環境下運行nmt