# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Sample Generate GPT"""
import shutil
import logging
import subprocess

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
import socket

from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.core import mpu
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron
# from megatron.core.models.gpt import GPTModel
from megatron.legacy.model import GPTModel
from megatron.training import get_model
from megatron.training.arguments import core_transformer_config_from_args
from megatron.inference.text_generation_server import MegatronServer
from megatron.inference.text_generation import generate_and_post_process
from megatron.inference.text_generation import beam_search_and_post_process


import torch
from torch import distributed as dist

import torch_mlu

from flask import Flask, request, jsonify
from tqdm import tqdm
import time

from preprocess import postprocess, save_answer, mmlu_preprocess, agi_preprocess, ceval_preprocess


def model_provider(pre_process=True, post_process=True):
    """Construct the legacy GPT model from the parsed Megatron arguments.

    Args:
        pre_process: Forwarded unchanged to ``GPTModel``.
        post_process: Forwarded unchanged to ``GPTModel``.

    Returns:
        The newly built ``GPTModel`` instance.
    """
    transformer_config = core_transformer_config_from_args(get_args())
    gpt_kwargs = dict(
        num_tokentypes=0,
        parallel_output=False,
        pre_process=pre_process,
        post_process=post_process,
    )
    return GPTModel(transformer_config, **gpt_kwargs)

def add_text_generate_args(parser):
    """Register the evaluation/chat command-line options on ``parser``.

    Adds ``--task``, ``--eval-data`` and ``--save-answer`` (all optional
    strings defaulting to ``None``) to a 'text generation' argument group.

    Returns:
        The same ``parser`` object that was passed in.
    """
    options = (
        ("--task", 'evaluation task'),
        ("--eval-data", 'evaluation data path'),
        ("--save-answer", 'path of model genereated answer'),
    )
    group = parser.add_argument_group(title='text generation')
    for flag, description in options:
        group.add_argument(flag, type=str, default=None, help=description)
    return parser

def print_flush(prev_str, curr_str):
    """Write to stdout the characters of ``curr_str`` that differ from ``prev_str``.

    Positions present in both strings are compared pairwise and only the
    mismatching characters of ``curr_str`` are kept; whatever ``curr_str``
    has beyond the length of ``prev_str`` is then appended unchanged.
    """
    changed = ''.join(new for old, new in zip(prev_str, curr_str) if old != new)
    # Empty unless curr_str is longer than prev_str.
    appended = curr_str[len(prev_str):]
    sys.stdout.write(changed + appended)

def init_model():
    """Build the generation model and load its checkpoint when one is given.

    Exits the process if ``num_layers_per_virtual_pipeline_stage`` is set,
    forces ``exit_on_missing_checkpoint`` to ``True`` on the global args, and
    returns the single model chunk produced by ``get_model``.
    """
    args = get_args()

    interleaved = args.num_layers_per_virtual_pipeline_stage is not None
    if interleaved:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text "
                 "generation.")
    args.exit_on_missing_checkpoint = True

    # Build the model chunks without a DDP wrapper, then restore weights if a
    # checkpoint path was supplied.
    model_chunks = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        load_checkpoint(model_chunks, None, None)

    assert len(model_chunks) == 1, "Above condition should have caught this"
    return model_chunks[0]

def gene(model, prompts, tokens_to_generate, *, beam_width=1, logprobs=False,
         temperature=1.0, top_k=1, top_p=0.0, random_seed=1234,
         stop_token=50256, length_penalty=1):
    """Generate text for ``prompts`` and return the post-processed result.

    The decoding settings that used to be hard-coded locals are keyword-only
    arguments whose defaults reproduce the previous behaviour (beam search
    with a beam of one), so existing three-argument calls are unaffected.

    Args:
        model: Model handed straight to the Megatron generation helpers.
        prompts: List of prompt strings.
        tokens_to_generate: Number of tokens to generate per prompt.
        beam_width: Beam size, also used as the number of beams returned.
            A falsy value (``0`` / ``None``) selects
            ``generate_and_post_process`` instead of beam search.
        logprobs: Sampling path only; forwarded as ``return_output_log_probs``.
        temperature: Sampling path only.
        top_k: Sampling path only; forwarded as ``top_k_sampling``.
        top_p: Sampling path only; forwarded as ``top_p_sampling``.
        random_seed: Sampling path only.
        stop_token: Beam-search path only.
        length_penalty: Beam-search path only.

    Returns:
        On success a dict with keys ``"text"`` and ``"segments"`` plus
        ``"scores"`` (beam search) or ``"logprobs"`` (sampling).
        On failure a ``str``: either the fixed message returned when
        ``tokens_to_generate`` is 0 without ``logprobs``, or the first
        argument of a ``ValueError`` raised by the generation helper.
        Callers must therefore check the type before indexing.
    """
    if tokens_to_generate == 0 and not logprobs:
        return "tokens_to_generate=0 implies logprobs should be True"

    # Decoding options that stay fixed for every call.
    top_p_decay = 0.0
    top_p_bound = 0.0
    add_BOS = False
    stop_on_double_eol = False
    stop_on_eol = False
    prevent_newline_after_colon = False

    # Every rank takes part in this broadcast from rank 0.
    # NOTE(review): the tensor is uninitialised and its value is never read in
    # this function; it looks like a leftover of the Megatron text-generation
    # server handshake - confirm before removing.
    choice = torch.mlu.LongTensor(1)
    torch.distributed.broadcast(choice, 0)

    try:
        if beam_width:
            response, response_seg, response_scores = \
                beam_search_and_post_process(
                    model,
                    prompts=prompts,
                    tokens_to_generate=tokens_to_generate,
                    beam_size=beam_width,
                    add_BOS=add_BOS,
                    stop_token=stop_token,
                    num_return_gen=beam_width,  # Returning whole beam
                    length_penalty=length_penalty,
                    prevent_newline_after_colon=prevent_newline_after_colon,
                )
            return {"text": response,
                    "segments": response_seg,
                    "scores": response_scores}

        response, response_seg, response_logprobs, _ = \
            generate_and_post_process(
                model,
                prompts=prompts,
                tokens_to_generate=tokens_to_generate,
                return_output_log_probs=logprobs,
                top_k_sampling=top_k,
                top_p_sampling=top_p,
                top_p_decay=top_p_decay,
                top_p_bound=top_p_bound,
                temperature=temperature,
                add_BOS=add_BOS,
                use_eod_token_for_early_termination=True,
                stop_on_double_eol=stop_on_double_eol,
                stop_on_eol=stop_on_eol,
                prevent_newline_after_colon=prevent_newline_after_colon,
                random_seed=random_seed,
            )
        return {"text": response,
                "segments": response_seg,
                "logprobs": response_logprobs}
    except ValueError as ve:
        # The error is reported to the caller as a plain message, not raised.
        return ve.args[0]

def evaluation(data_path, task, answer_path):
    """Evaluate the model on every file in ``data_path`` and print accuracies.

    For each file the matching ``*_preprocess`` helper yields prompts and
    labels, ``gene`` produces a short continuation per prompt, and
    ``postprocess`` scores the answers. Rank 0 prints one line per file and
    the average accuracy at the end.

    The model is read from the module-level name ``model`` (it is not a
    parameter), so that name must be defined before this function is called.

    Args:
        data_path: Directory whose entries are the evaluation files.
        task: One of ``'mmlu'``, ``'agi'`` or ``'ceval'``.
        answer_path: Prefix for saved answers; falsy to skip saving.

    Raises:
        ValueError: If ``task`` is not a supported evaluation task.
    """
    supported_tasks = ("mmlu", "agi", "ceval")
    if task not in supported_tasks:
        # Previously an unknown task only failed later with an unbound
        # ``prompt`` variable; fail early with a clear message instead.
        raise ValueError(
            f"unsupported evaluation task {task!r}; expected one of {supported_tasks}")
    preprocess = {
        "mmlu": mmlu_preprocess,
        "agi": agi_preprocess,
        "ceval": ceval_preprocess,
    }[task]

    # os.listdir() returns entries in arbitrary order; sort for reproducible runs.
    file_list = sorted(os.listdir(data_path))
    tokens_to_generate = 5
    acc_list = []

    for i, file in enumerate(file_list):
        subject = file.split('.')[0]
        prompt, label = preprocess(os.path.join(data_path, file))

        answers = []
        t1 = time.time()

        # Indexed access is kept because the preprocessors' return type is
        # defined outside this file.
        for j in range(len(prompt)):
            sentence = prompt[j]
            response = gene(model, [sentence], tokens_to_generate)
            if not response:
                print(f"Error {response}")
                ans = "no"
            else:
                try:
                    # Strip the prompt: keep only the text after its characters.
                    ans = response['text'][0][len(sentence):]
                except Exception:
                    # gene() returns an error string instead of a dict on failure.
                    print_rank_0(response)
                    ans = 'not get answer'
            # Record exactly one answer per prompt so answers stay aligned
            # with ``label`` even when generation fails.
            answers.append(ans)

        test_time = time.time() - t1

        if answer_path:
            # Plain concatenation: answer_path acts as a prefix and needs its
            # own trailing separator to denote a directory.
            save_answer(answers, answer_path + subject)

        acc, _ = postprocess(answers, label)
        acc_list.append(acc)

        log_string = '|  num: {}/{}  |'.format(i+1, len(file_list))
        log_string += '  acc: {:.4f}  |  subject name: {}  |'.format(acc, subject)
        log_string += '  evaluation time: {:.2f}s  |'.format(test_time)

        if torch.distributed.get_rank() == 0:
            print(log_string)

    if not acc_list:
        # Nothing was evaluated; avoid dividing by zero.
        if torch.distributed.get_rank() == 0:
            print(f'no evaluation files found in {data_path}')
        return

    avg_acc = sum(acc_list)/len(acc_list)
    if torch.distributed.get_rank() == 0:
        print('===================average accuracy',avg_acc )

def chat(model):
    """Interactive dialog mode with multiple rounds of conversation.

    Rank 0 reads a line from stdin, builds the full context from the system
    template, the stored (prompt, response) history and the new prompt, and
    every rank then calls ``gene`` to produce up to 100 new tokens. All ranks
    all-reduce a termination flag each round so that a quit command typed on
    rank 0 ends the loop everywhere.

    Commands on rank 0: ``q``/``quit``/``exit`` leave the loop, ``clear``/
    ``new`` reset the history, an empty line is ignored. A closed stdin (EOF)
    is treated as a quit request.

    Args:
        model: Model passed through to ``gene``.
    """
    # Alternative (unused) templates kept for reference:
    # system_template = "Below is an instruction that describes a task, paired with an input that provides further " \
    #                     "context. Write a response that appropriately completes the request. " \
    #                     "Please note that you need to think through your response logically and step by step.\n\n"
    # dialog_template = "### Instruction:\n{instruction}\n\n### Response:"

    # NOTE(review): the adjacent literals are joined without separators, so the
    # prompt contains "step by step.No other ..." and runs straight into
    # "Instruction:". Left untouched because editing it changes what the model
    # is fed - confirm whether that is intended.
    system_template = "Below is an instruction that describes a task, Write a response that appropriately completes the request. " \
                        "Please note that you need to think through your response logically and step by step." \
                        "No other things needed besided the response."
    dialog_template = "Instruction:\n{instruction}\nResponse:"

    def get_context(content):
        """Render the system template followed by every (prompt, response) turn."""
        res = system_template
        for q, r in content:
            res += dialog_template.format(instruction=q)
            if r is not None:
                res += r
        return res

    histories = []
    prompt, instruction = "", ""
    input_template, response_template = "\n\nYou >> ", "\nModel >>\n"
    command_clear = ["clear"]

    while True:
        # Per-round flag; rank 0 raises it on a quit command and the
        # all_reduce below propagates it to every rank.
        terminate_runs = torch.zeros(1, dtype=torch.int64, device=torch.mlu.current_device())

        if dist.get_rank() == 0:
            if not histories:
                logging.info("===========================================================")
                logging.info("1. If you want to quit, please entry one of [q, quit, exit]")
                logging.info("2. To create new title, please entry one of [clear, new]")
                logging.info("===========================================================")

            try:
                prompt = input(input_template)
            except EOFError:
                # stdin was closed: quit cleanly instead of crashing rank 0
                # while the other ranks wait in the all_reduce.
                prompt = "quit"
            # remove non utf-8 characters
            prompt = prompt.encode('utf-8', errors='ignore').decode('utf-8')
            command = prompt.strip()

            if command in ["q", "exit", "quit"]:
                terminate_runs += 1

            if command in ["clear", "new"]:
                subprocess.call(command_clear)
                histories = []
                continue

            if not command:
                continue

            # Context = stored turns plus the new, still unanswered prompt.
            instruction = get_context(histories + [(prompt, None)])

        dist.all_reduce(terminate_runs)
        dist.barrier()
        if terminate_runs > 0:
            break

        result = gene(model, [instruction], 100)
        if not isinstance(result, dict):
            # gene() returns its error message as a string; report it rather
            # than failing on result['text'].
            if dist.get_rank() == 0:
                sys.stdout.write(response_template)
                sys.stdout.write(f"[generation failed: {result}]")
            continue

        # Drop the echoed context: keep only the text after its characters.
        response = result['text'][0][len(instruction):].rstrip()

        if dist.get_rank() == 0:
            sys.stdout.write(response_template)
            sys.stdout.write(response)

        histories.append((prompt, response))

if __name__ == "__main__":
    # Script entry point: initialise Megatron with the extra CLI options,
    # build the model, then dispatch on --task.
    initialize_megatron(
        extra_args_provider=add_text_generate_args,
        args_defaults={'no_load_rng': True, 'no_load_optim': True},
    )
    args = get_args()
    # Keep this as a module-level name: evaluation() reads `model` as a global.
    model = init_model()

    task = args.task
    data_path = args.eval_data

    if task in ('mmlu', 'agi', 'ceval'):
        answer_path = args.save_answer
        evaluation(data_path, task, answer_path)
    elif task == "chat":
        chat(model)
    else:
        print("=================task not support")