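"""Watchdog wrapper for TPU training runs.

Launches main.py on the given TPU, mirrors TensorBoard scalars and eval
results into Sacred (viewable in Omniboard), and kills and recreates the
TPU whenever the run stops logging for longer than the heartbeat timeout.
"""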
import atexit
import sacred
import argparse
import time
import math
import subprocess
import shutil
import os
import json
import threading
import requests
import glob
from configs import fetch_model_params
import socket
import queue
import sys
import signal
parser = argparse.ArgumentParser()
parser.add_argument('--tpu', type=str, required=True) # name of the TPU to train on
parser.add_argument('--model', type=str, required=True) # JSON file that contains model parameters
parser.add_argument('--experiment_name', type=str, required=True) # name of experiment (will show up in omniboard)
parser.add_argument('--steps_per_checkpoint', type=int, default=5000)
parser.add_argument('--autostack', action="store_false") # stacking is on by default; passing this flag disables it
parser.add_argument('--auto_layout', action="store_true")
parser.add_argument('--auto_layout_and_mesh_shape', action="store_true")
parser.add_argument('--new', action='store_true')
parser.add_argument('--test', action='store_true')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--predict', action='store_true')
parser.add_argument('--no_delete_tpu', action='store_true')
parser.add_argument('--initial_heartbeat_timeout', type=int, default=7200) # grace period before heartbeat_timeout kicks in
parser.add_argument('--heartbeat_timeout', type=int, default=1800) # kill and restart if nothing logged to tensorboard in this many seconds
args = parser.parse_args()
params = fetch_model_params(args.model)
ex = sacred.Experiment(args.experiment_name)
ex.observers.append(sacred.observers.QueuedMongoObserver(url='127.0.0.1:27017', db_name='db', username='user', password='password'))
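# metrics are queued into a local MongoDB instance (read by Omniboard);
# the credentials here assume that local deployment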
def get_open_port(lo=8000, hi=8100):
    # connect_ex returns 0 when something is already listening, so a nonzero
    # result means the port is free
    for i in range(lo, hi):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex(('localhost', i)) != 0:
                return i
    raise RuntimeError('no open port found in range {}-{}'.format(lo, hi))
def train_thread(args, tpu, id, q):
print('starting training on', tpu)
    # pass binary flags through to main.py
    opts = ''
    for flag in ['auto_layout', 'auto_layout_and_mesh_shape', 'new', 'test', 'predict', 'eval']:
        if getattr(args, flag):
            opts += ' --' + flag
    # autostack is a store_false flag, so forward it only when it has been switched off
    for flag in ['autostack']:
        if not getattr(args, flag):
            opts += ' --' + flag
cmd = "python3 main.py --tpu {tpu} --model run_configs/config_{id}.json --steps_per_checkpoint {steps_per_checkpoint} {opts} --sacred_id {run_id}".format(tpu=tpu, id=id, steps_per_checkpoint=args.steps_per_checkpoint, opts=opts, run_id=id)
print('Running:', cmd)
proc = subprocess.Popen(cmd, shell=True)
# poll until it's exited
while proc.poll() is None:
time.sleep(60)
try:
nq, *nargs = q.get_nowait()
if nq == 'kill':
                print('train thread received kill signal from logging thread')
# first send SIGTERM
proc.terminate()
time.sleep(60)
# if it still hasn't exited, we send SIGKILL
if proc.poll() is None:
print('SIGTERM not successful, sending SIGKILL')
proc.kill()
except queue.Empty:
pass
print('exited training!')
if proc.returncode == 0:
print('exited gracefully')
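        # raise KeyboardInterrupt in the main thread so the logging loop
        # unwinds and the atexit handler runs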
os.kill(os.getpid(), signal.SIGINT)
return
if args.no_delete_tpu:
print('recreate done, exiting train_thread - not killing tpu!')
return
print("Recreating {} in 60sec...".format(tpu))
time.sleep(60)
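    # `pu` is the tpunicorn CLI; recreate deletes and re-provisions the
    # (likely preempted) TPU, retrying per the --retry/--retry-randomness flags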
os.system("pu recreate {} --yes --retry 3600 --retry-randomness 1.5".format(tpu))
print('recreate done, exiting train_thread')
# clear out queue
while True:
try:
q.get_nowait()
print('dropped request in queue after pu recreate')
except queue.Empty:
break
def get_json(uri, params=None, timeout=15):
resp = requests.get(uri, params=params, timeout=timeout)
resp.raise_for_status()
return resp.json()
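# the helpers below scrape tensorboard's scalars plugin HTTP API:
# /data/plugin/scalars/tags maps each run to its scalar tags, and
# /data/plugin/scalars/scalars returns a list of [wall_time, step, value]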
def get_tag_sets(base_uri):
j = get_json(f'{base_uri}/data/plugin/scalars/tags', {'experiment': ''})
assert isinstance(j, dict)
return {
run: j[run].keys()
for run in j.keys()
}
def get_scalar_data(base_uri, run, tag):
j = get_json(f'{base_uri}/data/plugin/scalars/scalars', {'experiment': '', 'run': run, 'tag': tag})
assert isinstance(j, list)
return j
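# gather the scalar series we mirror into sacred (train loss, eval loss,
# lambada metrics) from the local tensorboard instance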
def get_run_data(port):
base_uri = f'http://localhost:{port}/'
r = {}
try:
tag_sets = get_tag_sets(base_uri)
runs = tag_sets.keys()
if '.' in runs:
if 'loss' in tag_sets['.']:
r['loss'] = get_scalar_data(base_uri, '.', 'loss')
if 'eval' in runs:
if 'loss' in tag_sets['eval']:
r['val_loss'] = get_scalar_data(base_uri, 'eval', 'loss')
if 'eval_lambada' in runs:
if 'lambada_acc' in tag_sets['eval_lambada']:
r['lambada_acc'] = get_scalar_data(base_uri, 'eval_lambada', 'lambada_acc')
if 'lambada_log_ppl' in tag_sets['eval_lambada']:
r['lambada_ppl'] = [
[t, s, math.exp(lp)]
for [t, s, lp] in get_scalar_data(base_uri, 'eval_lambada', 'lambada_log_ppl')
]
    except Exception:
        # tolerate transient failures (e.g. tensorboard not serving yet) and
        # return whatever has been collected so far
        import traceback
        traceback.print_exc()
return r
@ex.main
def main(_run):
print('Starting run', _run._id)
print('experiment main invoked with argv:', " ".join(sys.argv))
print('WARNING: please remember to remove old metric log files from the model directory.')
os.makedirs('run_configs', exist_ok=True)
    # snapshot the model config for this run under its sacred id
    config_src = args.model if args.model.endswith('.json') else 'configs/{}.json'.format(args.model)
    shutil.copy(config_src, 'run_configs/config_{}.json'.format(_run._id))
tensorboard_port = get_open_port()
print('Tensorboard at port:', tensorboard_port)
print('Tensorboard url: ', 'http://eleutherai.bmk.sh:'+ str(tensorboard_port))
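    # run tensorboard detached inside a screen session; if this tensorboard
    # build doesn't support --bind_all, retry with a plain local bind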
os.system("screen -S tensorboard_{} -d -m bash -c 'tensorboard --logdir {} --port {} --bind_all --reload_multifile=true || tensorboard --logdir {} --port {} --reload_multifile=true'".format(_run._id, params["model_path"], tensorboard_port,params["model_path"], tensorboard_port,))
atexit.register(goodbye, _run._id)
curr_step = {}
seen_predictions = set()
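    # start with double the initial timeout as a warm-up grace period; once
    # args.initial_heartbeat_timeout has elapsed, the check below drops this
    # to the regular args.heartbeat_timeout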
heartbeat_timeout = args.initial_heartbeat_timeout * 2
while True:
last_tb_log_time = time.time()
start_time = time.time()
q = queue.Queue()
trainthd = threading.Thread(target=train_thread, args=(args, args.tpu, _run._id, q))
trainthd.start()
while trainthd.is_alive():
time.sleep(60)
if start_time + args.initial_heartbeat_timeout < time.time():
                # the initial grace period is over; drop to the stricter regular timeout
heartbeat_timeout = args.heartbeat_timeout
print('Polling tensorboard for metrics...')
data = get_run_data(tensorboard_port)
for k in data.keys():
for ts, step, val in data[k]:
if step <= curr_step.get(k, -1):
continue
_run.log_scalar(k, val, step)
if k == 'loss':
_run.log_scalar('tb_ts', ts, step)
print('Logged to sacred: step={},loss={},tb_ts={}'.format(step, val, ts))
                    # fresh data arrived, so reset the heartbeat timer
last_tb_log_time = time.time()
curr_step[k] = step
for f in glob.glob('predictions_{}_*'.format(_run._id)):
if f in seen_predictions:
continue
print('collecting prediction file', f)
ex.add_artifact(f)
seen_predictions.add(f)
# collect eval metrics from jsonl
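            # each line is a json object with 'task', 'global_step', and one
            # field per metric; every metric is mirrored as 'fs.<task>.<metric>'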
if os.path.exists(f'eval_{_run._id}.jsonl'):
with open(f'eval_{_run._id}.jsonl') as fh:
for line in fh:
ob = json.loads(line)
val_step = ob['global_step']
val_task = ob['task']
                        for metr in ob.keys():
                            if metr in ['task', 'global_step']:
                                continue
                            k = 'fs.' + val_task + '.' + metr
                            if val_step <= curr_step.get(k, -1):
                                continue
                            _run.log_scalar(k, ob[metr], val_step)
                            curr_step[k] = val_step
if time.time() - last_tb_log_time > heartbeat_timeout:
# the run hasn't logged in a while, so we restart it
q.put(('kill',))
# give training thread some time to do its thing and recreate tpu
while trainthd.is_alive():
                print('logging thread waiting for the stalled run to be killed and for tpu recreate to finish')
time.sleep(60)
# reset heartbeat timeout to initial
heartbeat_timeout = args.initial_heartbeat_timeout
last_tb_log_time = time.time()
if args.no_delete_tpu:
break
def goodbye(id):
print("You are now leaving the Python sector.")
print("Sie verlassen den pythonischen Sektor.")
os.system("screen -S tensorboard_{} -X quit".format(id))
if __name__ == '__main__':
    # register every python source file with sacred for reproducibility
    for file in glob.glob("**/*", recursive=True):
        if file.endswith('.py'):
            print('Adding', file, 'to sacred')
            ex.add_source_file(file)
ex.add_config({
'tpu_name': args.tpu,
**params
})
ex.run()
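
# Example invocation (script filename and resource names are illustrative):
#   python3 run_experiment.py --tpu my-tpu-v3-8 --model gpt3_small \
#       --experiment_name gpt3-small-run --steps_per_checkpoint 5000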