import json
import argparse
import collections
import numpy as np
import matplotlib
import matplotlib.pyplot as plt


# Part1: log function
def log_info(msg):
    """Print ``msg`` to stdout with an ``[INFO]: `` prefix."""
    text = "[INFO]: {}".format(msg)
    print(text)

def log_warn(msg):
    """Print ``msg`` to stdout as a ``[WARNING]: `` line wrapped in ANSI colour codes."""
    text = "\033[1;33;40m[WARNING]: {}\033[0m".format(msg)
    print(text)

def log_err(msg):
    """Print ``msg`` to stdout as an ``[ERROR]: `` line wrapped in ANSI colour codes."""
    text = "\033[1;31;40m[ERROR]: {}\033[0m".format(msg)
    print(text)

# Part2: Quantization precision error
def cal_rel_err(mlu_loss, baseline_loss):
    """Return the element-wise relative error ``|(mlu - baseline) / baseline|``.

    Both inputs are truncated to the length of the shorter one before they
    are compared. Plain Python sequences are accepted as well as NumPy arrays.

    Args:
        mlu_loss: Loss values of the run under test.
        baseline_loss: Reference loss values.

    Returns:
        A NumPy array with one non-negative relative error per compared step.
        Steps whose baseline value is 0 follow NumPy division semantics
        (``inf`` or ``nan``).
    """
    min_len = min(len(mlu_loss), len(baseline_loss))
    # Convert first: subtracting two plain lists would raise TypeError.
    mlu_arr = np.asarray(mlu_loss[:min_len])
    baseline_arr = np.asarray(baseline_loss[:min_len])
    return np.abs((mlu_arr - baseline_arr) / baseline_arr)

def cal_abs_err(mlu_loss, baseline_loss):
    """Return the element-wise absolute error ``|mlu - baseline|``.

    Both inputs are truncated to the length of the shorter one before they
    are compared. Plain Python sequences are accepted as well as NumPy arrays.

    Args:
        mlu_loss: Loss values of the run under test.
        baseline_loss: Reference loss values.

    Returns:
        A NumPy array with one absolute error per compared step.
    """
    min_len = min(len(mlu_loss), len(baseline_loss))
    # Convert first: subtracting two plain lists would raise TypeError.
    mlu_arr = np.asarray(mlu_loss[:min_len])
    baseline_arr = np.asarray(baseline_loss[:min_len])
    return np.abs(mlu_arr - baseline_arr)

def cal_R_square(mlu_loss, baseline_loss):
    """Return the squared Pearson correlation between the two loss curves.

    Used as a measure of how closely the curve under test follows the
    baseline curve: the value is 1 when the two curves are perfectly linearly
    correlated and 0 when they show no linear correlation. Both inputs are
    truncated to the length of the shorter one first.
    """
    common_len = min(len(mlu_loss), len(baseline_loss))
    candidate = np.array(mlu_loss[:common_len])
    reference = np.array(baseline_loss[:common_len])
    corr_matrix = np.corrcoef(candidate, reference)
    pearson_r = corr_matrix[1, 0]
    return pearson_r ** 2

def cal_var(mlu_loss, baseline_loss):
    """Return the variance of the step-wise difference ``mlu - baseline``.

    Both inputs are truncated to the length of the shorter one before they
    are compared. Plain Python sequences are accepted as well as NumPy arrays.

    Args:
        mlu_loss: Loss values of the run under test.
        baseline_loss: Reference loss values.

    Returns:
        The variance (as computed by ``np.var``) of the differences.
    """
    min_len = min(len(mlu_loss), len(baseline_loss))
    # Convert first: subtracting two plain lists would raise TypeError.
    diff_value = np.asarray(mlu_loss[:min_len]) - np.asarray(baseline_loss[:min_len])
    return np.var(diff_value)

def check_loss(mlu_loss, baseline_loss, avg_rel_err_threshold):
    """Check that the loss curve under test stays close to the baseline.

    The two sequences are compared step by step over the length of the
    shorter one. A step counts as anomalous when

    * the baseline loss is >= 1 and the relative error exceeds 3%, or
    * the baseline loss is < 1 and the absolute error exceeds 0.02.

    The check passes when at most 1% of the compared steps are anomalous and
    the mean relative error does not exceed ``avg_rel_err_threshold``.

    Args:
        mlu_loss: Loss values of the run under test.
        baseline_loss: Reference loss values.
        avg_rel_err_threshold: Upper bound for the mean relative error.

    Returns:
        True when the check passes. Otherwise an error is logged and False is
        returned; this includes the case where there is nothing to compare.
    """
    min_len = min(len(mlu_loss), len(baseline_loss))
    if min_len == 0:
        # Averaging over zero steps would divide by zero, and an empty
        # comparison must not be reported as a pass.
        log_err('No overlapping loss values to compare.')
        return False

    fail_num = 0
    rel_err_pool = []
    for mlu_value, baseline_value in zip(mlu_loss, baseline_loss):
        abs_err = abs(mlu_value - baseline_value)
        if baseline_value != 0:
            # Divide by the magnitude so a negative baseline cannot produce a
            # negative "error" that would drag the average down.
            rel_err = abs_err / abs(baseline_value)
        else:
            # Zero baseline: identical values have no error, anything else is
            # unbounded relative to the baseline.
            rel_err = 0.0 if abs_err == 0 else float('inf')
        rel_err_pool.append(rel_err)

        if baseline_value >= 1:
            if rel_err > 0.03:
                fail_num += 1
        elif abs_err > 0.02:
            fail_num += 1

    avg_rel_err = sum(rel_err_pool) / len(rel_err_pool)
    if fail_num <= min_len * 0.01 and avg_rel_err <= avg_rel_err_threshold:
        return True

    log_err(f'Maybe the number of loss with anomalies({fail_num}) exceeds 1% of the total({min_len}), '
            f'or avg rel err of loss({avg_rel_err}) exceeds the threshold({avg_rel_err_threshold}).')
    return False

# Part3: draw loss    
def configure_plt_pattern():
    """Apply the shared plot settings: 300 dpi for saved images and a visible grid.

    The grid is applied to pyplot's current axes.
    """
    plt.rcParams['savefig.dpi'] = 300  # resolution of images written by savefig
    # Pass True explicitly: a bare plt.grid() toggles the grid, which would
    # switch it off when the active style already enables it.
    plt.grid(True)

def draw_multi_lines(mlu_data, baseline_data, title, xlabel, ylabel, save_path):
    """Plot the MLU and baseline series on one figure and save it to ``save_path``.

    Both series are cut to the length of the shorter one and drawn against
    their 0-based step index, labelled "MLU" and "Baseline".
    """
    common_len = min(len(mlu_data), len(baseline_data))
    steps = list(range(common_len))
    series = (
        ("MLU", mlu_data[:common_len]),
        ("Baseline", baseline_data[:common_len]),
    )

    plt.figure()
    configure_plt_pattern()
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)

    for label, values in series:
        plt.plot(steps, values, label=label, linewidth=1)

    plt.legend()
    plt.ion()
    plt.savefig(save_path)
    
def draw_single_line(data, title, xlabel, ylabel, save_path):
    """Plot one series against its 0-based step index, save it and show it.

    Args:
        data: Values to plot.
        title: Figure title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        save_path: Destination of the saved image.
    """
    steps = list(range(len(data)))

    plt.figure()
    configure_plt_pattern()
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)

    plt.plot(steps, data, linewidth=1)

    plt.ion()
    # Save before show(): depending on the backend, show() can block until the
    # window is closed or release the figure, after which savefig would write
    # an empty canvas.
    plt.savefig(save_path)
    plt.show()


# Part4: parse the logs generated by megatron
def parse_log(log_file):
    """Parse a Megatron training log into per-iteration metrics.

    Every line that contains "elapsed time per iteration" and an "lm loss"
    field becomes one record; iteration lines without "lm loss" (overflow
    steps) are skipped. Lines containing "best_mfu: " or "best_tgs: " supply
    the two summary values; when several such lines exist the last one wins.

    Args:
        log_file: Path of the log file to read.

    Returns:
        A list ``[log_dict, best_mfu, best_tgs]``. ``log_dict`` is an
        OrderedDict mapping the iteration number to
        ``[elapsed_time_ms, learning_rate, global_batch_size, lm_loss,
        loss_scale, grad_norm, skipped_iters, nan_iters]``. ``best_mfu`` and
        ``best_tgs`` are floats, or None when the log has no such line.

    Raises:
        ValueError, IndexError: If a matched line lacks an expected field or
            the field is not a number.
    """
    target_str = "elapsed time per iteration"
    mfu_str = "best_mfu: "
    tgs_str = "best_tgs: "
    best_mfu = None
    best_tgs = None
    log_dict = collections.OrderedDict()

    def field(text, flag):
        # Text between `flag` and the next " |" separator (or the line end).
        return text.split(flag)[1].strip().split(" |")[0]

    def summary_value(text, flag):
        # First whitespace-delimited token that follows `flag`.
        return text.split(flag, 1)[1].split()[0]

    # Read as bytes and decode leniently so stray non-UTF-8 bytes in the log
    # cannot abort parsing. Decoding (rather than str(bytes)) yields the real
    # text instead of its b'...' repr, so values parse correctly whatever the
    # line ending is, including a final line without a newline.
    with open(log_file, "rb") as fr:
        for raw_line in fr:
            line = raw_line.decode("utf-8", errors="replace").strip()
            if target_str in line:
                # overflow step, skip it
                if "lm loss" not in line:
                    continue

                # The step is the number right before the "/" of "step/total".
                step = int(line.split("/")[0].split()[-1])
                log_dict[step] = [
                    float(field(line, "elapsed time per iteration (ms): ")),
                    float(field(line, "learning rate: ")),
                    int(field(line, "global batch size:")),
                    float(field(line, "lm loss: ")),
                    float(field(line, "loss scale: ")),
                    float(field(line, "grad norm: ")),
                    int(field(line, "number of skipped iterations: ")),
                    int(field(line, "number of nan iterations: ")),
                ]
            elif mfu_str in line:
                best_mfu = float(summary_value(line, mfu_str))
            elif tgs_str in line:
                best_tgs = float(summary_value(line, tgs_str))
    return [log_dict, best_mfu, best_tgs]
