import os
import argparse
import collections
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from utils import log_info, log_warn, log_err
from utils import cal_rel_err, check_loss
from utils import draw_multi_lines, draw_single_line
from utils import parse_log

def get_args():
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options ``parse_type``,
        ``platform_mode``, ``avg_rel_err_threshold``, ``abs_err_threshold``,
        ``perf_tolerance``, ``baseline_logs`` and ``mlu_logs``.
    """
    cli = argparse.ArgumentParser(description='Parse Arguments')
    # Each entry is (flag, keyword options for add_argument), registered in order.
    option_table = [
        ('--parse_type',
         dict(type=str, default='all',
              choices=['loss', 'perf', 'all', 'none'])),
        ('--platform_mode',
         dict(type=str, default='online',
              choices=['online', 'ONLINE', 'offline', 'OFFLINE'],
              help="baselinelogs depends on platform mode")),
        ('--avg_rel_err_threshold',
         dict(type=float, default=0.02,
              help="relative error threshold")),
        ('--abs_err_threshold',
         dict(type=float, default=0.1,
              help="absolute error threshold")),
        ('--perf_tolerance',
         dict(type=float, default=0.03,
              help="performance tolerance")),
        ('--baseline_logs',
         dict(type=str, default="./tests/demo_tests/baseline_logs",
              help="baseline logs path")),
        ('--mlu_logs',
         dict(type=str, default="./tests/demo_tests/logs",
              help="mlu logs path")),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def update_args(args):
    """Point ``args.baseline_logs`` at the baseline directory for the platform.

    ``args.platform_mode`` is matched case-insensitively. For "online" and
    "offline" the ``baseline_logs`` attribute is overwritten, which replaces
    any value given via ``--baseline_logs``. Any other mode leaves it
    untouched. The namespace is modified in place and also returned.
    """
    baseline_dir_by_mode = {
        "online": "/workspace/dataset/favorite/soft-data-platform/v1/models/demo_tests/model_baseline_logs",
        "offline": "/data/platform/models/demo_tests/model_baseline_logs",
    }
    mode = args.platform_mode.lower()
    if mode in baseline_dir_by_mode:
        args.baseline_logs = baseline_dir_by_mode[mode]
    return args



def show_loss_msg(parse_type, mlu_loss, baseline_loss, model_name,
                  avg_rel_err_threshold=None):
    """Check an MLU loss curve against its baseline and plot it on failure.

    Args:
        parse_type: Not used by this function; kept for call compatibility.
        mlu_loss: Loss values parsed from the MLU log.
        baseline_loss: Loss values parsed from the baseline log.
        model_name: Used to build the output image file names.
        avg_rel_err_threshold: Threshold forwarded to ``check_loss``. When
            ``None`` (the default), the value is read from the module-level
            ``args`` namespace that the ``__main__`` block binds, which is
            the original behavior.

    Returns:
        True if ``check_loss`` accepts the curve. Otherwise False, after
        writing ``<model_name>-loss.png`` (both curves) and
        ``<model_name>-rel-err-loss.png`` (relative error) next to this file.
    """
    if avg_rel_err_threshold is None:
        # Backward-compatible fallback: without this the function only works
        # when the file runs as a script and the global ``args`` exists.
        avg_rel_err_threshold = args.avg_rel_err_threshold

    # Computed up front (as before) so helper call order is unchanged.
    rel_err = cal_rel_err(mlu_loss, baseline_loss)
    if check_loss(mlu_loss, baseline_loss, avg_rel_err_threshold):
        return True

    out_dir = os.path.dirname(__file__)
    loss_title = f"{model_name}-loss.png"
    rel_err_title = f"{model_name}-rel-err-loss.png"

    # Draw the MLU and baseline loss curves together for comparison.
    draw_multi_lines(mlu_loss, baseline_loss, loss_title, "steps", "loss",
                     os.path.join(out_dir, loss_title))
    # Draw the relative error returned by cal_rel_err.
    draw_single_line(rel_err, rel_err_title, "steps", "loss_rel_err",
                     os.path.join(out_dir, rel_err_title))
    return False

def show_perf_msg(mlu_perf, baseline_perf, model_name, perf_tolerance=None):
    """Check the average elapsed time of an MLU run against its baseline.

    Args:
        mlu_perf: ``[mfu, tgs, elapsed_times]`` for the MLU run. Only the
            elapsed times (index 2) are used.
        baseline_perf: Same layout for the baseline run.
        model_name: Used to build the output image file name.
        perf_tolerance: Maximum allowed fractional increase of the average
            elapsed time (0.03 means 3%). When ``None`` (the default), the
            value is read from the module-level ``args`` namespace that the
            ``__main__`` block binds, which is the original behavior.

    Returns:
        True if the slowdown is within tolerance. Otherwise False, after
        logging an error and writing a comparison plot next to this file.
    """
    if perf_tolerance is None:
        # Backward-compatible fallback to the CLI namespace bound in __main__.
        perf_tolerance = args.perf_tolerance

    mlu_time = mlu_perf[2]
    baseline_time = baseline_perf[2]
    # Averages are rounded to 2 decimals before being compared.
    avg_mlu_time = round(np.mean(mlu_time), 2)
    avg_baseline_time = round(np.mean(baseline_time), 2)

    time_growth_percent = (avg_mlu_time - avg_baseline_time) / avg_baseline_time
    if time_growth_percent > perf_tolerance:
        # Report the tolerance actually in effect, not a hard-coded 3%.
        log_err(f'[Perf]: elapsed time increased by {round(100*time_growth_percent, 2)}%, '
                f'and exceeds the threshold({100 * perf_tolerance:g}%)')
        perf_title = f"{model_name}-elapsed-time"
        perf_save_path = os.path.join(os.path.dirname(__file__), perf_title)
        # The first entry is left out of the plot (presumably a warm-up step;
        # confirm), but it is included in the averages above.
        draw_multi_lines(mlu_time[1:], baseline_time[1:], perf_title,
                         "steps", "elapsed_time(ms)", perf_save_path)
        return False
    return True


def main(args):
    """Compare every MLU log in ``args.mlu_logs`` against its baseline log.

    Each file name in ``args.mlu_logs`` is mapped to a model name built from
    its 2nd and 3rd ``-``-separated fields (e.g. ``llama2-7B``). The matching
    baseline is ``<args.baseline_logs>/log-<model_name>.txt``. Depending on
    ``args.parse_type``, the loss (``'loss'``/``'all'``) and/or the elapsed
    time (``'perf'``/``'all'``) are checked.

    Args:
        args: Parsed CLI namespace (see ``get_args``).

    Raises:
        AssertionError: If any model failed a check. This includes models
            whose baseline log is missing or whose MLU log produced no data.
    """
    failed_loss_cases = []
    failed_perf_cases = []
    # List the directory once so the summary count matches what was iterated.
    # Sort it so the report order is stable, since os.listdir order is arbitrary.
    mlu_log_files = sorted(os.listdir(args.mlu_logs))

    for mlu_log_file in mlu_log_files:
        name_fields = mlu_log_file.strip().split("-")
        model_name = name_fields[1] + "-" + name_fields[2]  # llama2-7B
        baseline_log_file = os.path.join(args.baseline_logs,
                                         "log-" + model_name + ".txt")

        print(f"{model_name} message:")
        if not os.path.exists(baseline_log_file):
            # Nothing can be compared without a baseline. Record the model as
            # failed and move on instead of parsing a file that isn't there.
            log_err(f"baseline log file({baseline_log_file}) does not exist")
            print("")
            failed_loss_cases.append(mlu_log_file)
            failed_perf_cases.append(mlu_log_file)
            continue

        baseline_msg, baseline_mfu, baseline_tgs = parse_log(baseline_log_file)
        mlu_msg, mlu_mfu, mlu_tgs = parse_log(
            os.path.join(args.mlu_logs, mlu_log_file))
        if len(mlu_msg) == 0:
            log_err(f"invalid mlu log file, no loss data was parsed, mlu log file is {mlu_log_file}")
            print("")
            failed_loss_cases.append(mlu_log_file)
            failed_perf_cases.append(mlu_log_file)
            continue

        if args.parse_type in ['loss', 'all']:
            # Index 3 of each parsed record is used as the loss value.
            mlu_loss = np.array([log[3] for log in mlu_msg.values()])
            baseline_loss = np.array([log[3] for log in baseline_msg.values()])
            if show_loss_msg(args.parse_type, mlu_loss, baseline_loss, model_name):
                log_info(f"{model_name}'s loss is normal.")
            else:
                failed_loss_cases.append(mlu_log_file)
        if args.parse_type in ['perf', 'all']:
            # Index 0 of each parsed record is used as the elapsed time.
            mlu_time = np.array([log[0] for log in mlu_msg.values()])
            baseline_time = np.array([log[0] for log in baseline_msg.values()])
            mlu_perf = [mlu_mfu, mlu_tgs, mlu_time]
            baseline_perf = [baseline_mfu, baseline_tgs, baseline_time]
            if show_perf_msg(mlu_perf, baseline_perf, model_name):
                log_info(f"{model_name}'s performance is normal.")
            else:
                failed_perf_cases.append(mlu_log_file)
        if args.parse_type in ['none']:
            log_info("do nothing")

    print("")
    print("Summary:")
    log_info(f"{len(mlu_log_files)} models tested")
    if failed_loss_cases:
        log_err(f"{len(failed_loss_cases)} models loss is abnormal, list: {failed_loss_cases}")
    if failed_perf_cases:
        log_err(f"{len(failed_perf_cases)} models perf is abnormal, list: {failed_perf_cases}")
    if failed_loss_cases or failed_perf_cases:
        # Raise explicitly rather than using `assert`: an assert statement is
        # stripped under `python -O`, which would let a failing run pass.
        raise AssertionError("The loss or perf of some models are abnormal.")
    log_info("All the models passed the test.")


if __name__ == "__main__":
    # Bound as a module-level global on purpose: show_loss_msg and
    # show_perf_msg read their thresholds from `args`.
    args = update_args(get_args())
    main(args)
