import os
import glob
import csv
import json
import torch
import argparse
import torch_mlu

# Check that the required dependencies are installed (the import fails otherwise).
import cndb
import apex
import eight
import flash_attn
import transformer_engine



# Values accepted by the --platform-mode command-line option.
PLATFORM_MODE = ["ONLINE", "OFFLINE"]

# Header of the generated perf.csv, in column order. Only these keys are kept
# from each flattened JSON data file.
csv_column = [
    'model',
    'framework',
    'device_name',
    'cpu',
    'device_num_per_node',
    'node_num',
    'tp',
    'pp',
    'dp',
    'seq_len',
    'head_num',
    'hidden_size',
    'micro_bs',
    'global_bs',
    'e2e_iteration_time_sec',
    'MFU',
    'HFU',
    'TGS',
]

def params_check(common_options):
    """Validate the launch topology declared in ``common_options``.

    Args:
        common_options: Mapping providing the entries ``nproc_per_node`` and
            ``nnodes``.

    Raises:
        AssertionError: If ``nnodes`` is not 1, or if
            ``nproc_per_node * nnodes`` is larger than the value returned by
            ``torch.mlu.device_count()``.
    """
    nproc = common_options['nproc_per_node']
    nnodes = common_options['nnodes']

    # Raised explicitly instead of via ``assert`` so the validation is not
    # stripped under ``python -O``; the exception type stays AssertionError.
    if nnodes != 1:
        raise AssertionError(f"Only support single node, now nnodes is {nnodes}")

    # The device is queried only after the cheap node-count check has passed.
    world_size = nproc * nnodes
    device_count = torch.mlu.device_count()
    if world_size > device_count:
        raise AssertionError(
            f"world_size({world_size}) exceeds device count({device_count})")

def sanity_check(common_options, tp, pp):
    """Report whether the world size is a multiple of ``tp * pp``.

    Args:
        common_options: Mapping providing ``nproc_per_node`` and ``nnodes``;
            their product is the world size.
        tp: Tensor-parallel size to test.
        pp: Pipeline-parallel size to test.

    Returns:
        True when ``tp * pp`` is positive and evenly divides the world size;
        otherwise an error line is printed and False is returned.
    """
    world_size = common_options['nproc_per_node'] * common_options['nnodes']
    model_parallel_size = tp * pp

    # A non-positive product is rejected up front: zero would otherwise raise
    # ZeroDivisionError in the modulo below.
    if model_parallel_size > 0 and world_size % model_parallel_size == 0:
        return True

    print(f"[Error]: world_size({world_size}) is not divisible by "
          f"tp({tp}) * pp({pp})")
    return False

def train_status_check(log_file):
    """Scan a training log for an out-of-memory marker.

    Args:
        log_file: Path of the log file; it is read in binary mode.

    Returns:
        "OOM_Error" if any line contains "OutOfMemoryError", otherwise "OK".
    """
    marker = b"OutOfMemoryError"
    with open(log_file, 'rb') as log:
        # any() stops at the first matching line, like the original break.
        hit = any(marker in raw_line for raw_line in log)
    return "OOM_Error" if hit else "OK"


def load_json(json_file):
    """Print the path of a JSON file, then parse and return its content.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The decoded JSON document.
    """
    print(f"json_file: {json_file}")
    with open(json_file, 'r') as handle:
        text = handle.read()
    return json.loads(text)

def get_latest_log(dirname):
    """Return the most recently modified entry directly inside ``dirname``.

    Args:
        dirname: Directory whose entries (matched with the ``*`` glob) are
            compared by modification time.

    Returns:
        The path with the newest modification time, or None when the glob
        matches nothing.
    """
    candidates = glob.glob(os.path.join(dirname, '*'))
    if candidates:
        # max() keeps the first entry among equal mtimes, which is the same
        # element a stable descending sort would place first.
        return max(candidates, key=os.path.getmtime)
    return None

def generate_perf_data(log,
        model_name,
        root_dir,
        save_dir=None,
        nproc_per_node=None,
        nnodes=None,
        train_mode="pretrain",
        framework_name="Megatron"):
    """Run ``cndb_autogen`` on a training log to produce a perf data file.

    Args:
        log: Path of the training log passed as ``--data-file``. When it is
            empty or None nothing is executed.
        model_name: Value for ``--model-name``.
        root_dir: Value for ``--framework-root-dir``.
        save_dir: Optional value for ``--save-path``.
        nproc_per_node: Optional value for ``--device-num-per-node``; only
            forwarded when ``nnodes`` is also truthy.
        nnodes: Optional value for ``--node-num``; only forwarded when
            ``nproc_per_node`` is also truthy.
        train_mode: Value for ``--train-mode``.
        framework_name: Value for ``--framework-name``.

    Returns:
        None. A message is printed when the log is missing or when the
        command exits with a non-zero status.
    """
    import shlex

    # Without a log the command would receive the literal text "None".
    if not log:
        print("[Error]: no training log found, skip generating perf data")
        return

    argv = [
        "cndb_autogen",
        "--data-file", log,
        "--model-name", model_name,
        "--train-mode", train_mode,
        "--framework-name", framework_name,
        "--framework-root-dir", root_dir,
        "--log-level", "INFO",
        "--tag", "benchmark_submit",
    ]
    if save_dir:
        argv += ["--save-path", save_dir]
    if nproc_per_node and nnodes:
        argv += ["--device-num-per-node", nproc_per_node, "--node-num", nnodes]

    # Quote every argument so paths containing spaces or shell
    # metacharacters reach cndb_autogen as a single, literal argument.
    cmd = " ".join(shlex.quote(str(arg)) for arg in argv)
    status = os.system(cmd)
    if status != 0:
        print(f"[Warning]: cndb_autogen exited with status {status}: {cmd}")

def flatten_dict(data, parent_key=''):
    """Collapse a nested dict into a single-level dict.

    Nested dicts are merged into the result recursively; only the key of the
    immediately enclosing dict is tracked. A ``name`` entry is stored as
    ``framework`` when its parent key is ``framework`` and as ``device_name``
    when its parent key is ``device_info``. Every other entry keeps its own
    key, so later duplicates overwrite earlier ones.

    Args:
        data: The dict to flatten.
        parent_key: Key under which ``data`` was found ('' at the top level).

    Returns:
        A new flat dict.
    """
    # Decide once which key a 'name' entry at this level is stored under.
    if parent_key == 'framework':
        name_alias = 'framework'
    elif parent_key == 'device_info':
        name_alias = 'device_name'
    else:
        name_alias = None

    flat = {}
    for key, value in data.items():
        if isinstance(value, dict):
            flat.update(flatten_dict(value, key))
            continue
        if name_alias is not None and key == 'name':
            flat[name_alias] = value
        else:
            flat[key] = value
    return flat

def generate_perf_data_csv(data_dir):
    """Collect every ``*.json`` file in ``data_dir`` into ``csv/perf.csv``.

    Each JSON file is converted to one row with ``json_to_csv``. The rows are
    written, after a header built from ``csv_column``, to
    ``<data_dir>/csv/perf.csv``; the ``csv`` directory is created when it does
    not exist yet. When no JSON file is found a message is printed and nothing
    is written.

    Args:
        data_dir: Directory that holds the JSON data files.
    """
    json_files = glob.glob(os.path.join(data_dir, "*.json"))
    if not json_files:
        print("There are no data files in the path:{}".format(data_dir))
        return

    # Convert all files before touching the output location.
    records = [json_to_csv(json_file) for json_file in json_files]

    csv_dir = os.path.join(data_dir, 'csv')
    if not os.path.exists(csv_dir):
        os.mkdir(csv_dir)

    csv_path = os.path.join(csv_dir, 'perf.csv')
    with open(csv_path, 'w', newline='') as out:
        writer = csv.DictWriter(out, fieldnames=csv_column)
        writer.writeheader()
        for record in records:
            writer.writerow(record)

def json_to_csv(json_file):
    """Load a JSON data file and keep only the entries that are CSV columns.

    Args:
        json_file: Path of the JSON file to convert.

    Returns:
        A dict holding the flattened entries whose key appears in
        ``csv_column``.
    """
    flat = flatten_dict(load_json(json_file))
    row = {}
    for column, value in flat.items():
        if column in csv_column:
            row[column] = value
    return row

def submit_data(data_dir):
    """Run ``cndb_submit`` on every ``*.json`` file found in ``data_dir``.

    Args:
        data_dir: Directory that holds the JSON data files.

    Returns:
        None. A message is printed when no file is found, and a warning is
        printed for each file whose command exits with a non-zero status.
    """
    import shlex

    data_files = glob.glob(os.path.join(data_dir, "*.json"))
    if not data_files:
        print("There are no data files in the path:{}".format(data_dir))
        return

    for data_file in data_files:
        # Quote the path so spaces or shell metacharacters in it cannot split
        # or alter the command line.
        cmd = f"cndb_submit --data-file {shlex.quote(data_file)} --log-level INFO"
        status = os.system(cmd)
        if status != 0:
            print(f"[Warning]: cndb_submit exited with status {status} "
                  f"for {data_file}")

def launch_benchmark(model_config, args):
    """Run every configured training combination and generate its perf data.

    For each model in ``model_config['model_options']`` the function walks all
    combinations of sequence length, tensor-parallel size, pipeline-parallel
    size, accumulation steps and micro batch size. Each combination whose
    ``tp * pp`` passes ``sanity_check`` is launched with ``bash`` using the
    model's script, and the newest entry in the ``logs`` directory next to
    this file is then handed to ``generate_perf_data``.

    Args:
        model_config: Mapping with the keys ``common_options`` (providing
            ``nproc_per_node`` and ``nnodes``), ``scripts_path`` (model name ->
            script path relative to this file) and ``model_options`` (model
            name -> mapping with the iterables ``mb``, ``accu``, ``tp``,
            ``pp``, ``seq`` and the value ``train_mode``).
        args: Parsed command-line arguments; ``platform_mode`` and
            ``save_dir`` are read.

    Raises:
        AssertionError: Propagated from ``params_check``.
        KeyError: If a required key, or the script of a model, is missing.
    """
    common_options = model_config['common_options']
    scripts_path = model_config['scripts_path']
    model_options = model_config['model_options']
    nproc_per_node = common_options["nproc_per_node"]
    nnodes = common_options["nnodes"]

    params_check(common_options)

    # Resolve to an absolute directory: when __file__ is a bare relative name
    # (script run from its own directory on Python < 3.9) dirname() returns ''
    # and the paths below would become "/<script>" and "/logs".
    script_path = os.path.dirname(os.path.abspath(__file__))

    for model_name, train_pattern in model_options.items():
        # Parse the JSON description of this model's benchmark matrix.
        micro_batch = train_pattern['mb']
        accu_steps = train_pattern['accu']
        tp_size = train_pattern['tp']
        pp_size = train_pattern['pp']
        seq_len = train_pattern['seq']
        train_mode = train_pattern['train_mode']
        # Looked up once per model so a missing script is reported before any
        # training run of that model starts.
        scripts_name = scripts_path[model_name]

        # Launch the benchmark for every combination.
        for seq in seq_len:
            for tp in tp_size:
                for pp in pp_size:
                    # The check depends only on tp and pp, so an invalid pair
                    # is skipped (and reported) once instead of once per
                    # accumulation value.
                    if not sanity_check(common_options, tp, pp):
                        continue
                    for accu in accu_steps:
                        for mb in micro_batch:
                            # NOTE(review): presumably clears files left by a
                            # previous run -- confirm against the `eight` API.
                            f = eight.Frame()
                            f.cleanFile()
                            cmd = f"bash {script_path}/{scripts_name} {args.platform_mode} {mb} {accu} {tp} {pp} {seq}"
                            print(f"[INFO]: train commands: {cmd}")
                            os.system(cmd)
                            # Auto-generate the data file from the newest log.
                            log = get_latest_log(f"{script_path}/logs")
                            if log is None:
                                print(f"[Error]: no log found in {script_path}/logs, "
                                      "skip generating perf data")
                                continue
                            generate_perf_data(log, model_name, script_path, args.save_dir,
                                               nproc_per_node, nnodes, train_mode)


def parse_args():
    """Parse the benchmark's command-line options.

    Returns:
        The argparse namespace with ``save_dir``, ``submit`` and
        ``platform_mode``.
    """
    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        ('--save-dir', {
            'default': "/tmp/cndb_data",
            'help': "Where to save the auto-generated data files. The default path is /tmp/cndb_data",
        }),
        ('--submit', {
            'action': "store_true",
            'help': ("Whether to submit data to database. when the param presents, "
                     "it will be true which means data will be submitted to database."),
        }),
        ('--platform-mode', {
            'choices': PLATFORM_MODE,
            'default': "ONLINE",
            'help': "Which platform to run this benchmark. The default is ONLINE",
        }),
    )

    parser = argparse.ArgumentParser(description="Benchmark params")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def main():
    """Run the benchmark, write the CSV summary and optionally submit data.

    The model configuration is read from ``model_options.json`` located in the
    directory of this file.
    """
    args = parse_args()

    config_path = os.path.join(os.path.dirname(__file__), "model_options.json")
    launch_benchmark(load_json(config_path), args)

    # Summarize the generated data files into <save_dir>/csv/perf.csv.
    generate_perf_data_csv(args.save_dir)

    # Submission only happens when --submit was given.
    if args.submit:
        submit_data(args.save_dir)


if __name__ == '__main__':
    main()
