# Copyright (c) 2024, Cambricon CORPORATION. All rights reserved.

"""General utilities."""
import importlib.metadata as metadata
import torch
try:
    import cntrainkit
except:
    cntrainkit = None

from megatron.training import (
    get_args,
    get_adlr_autoresume,
)
from megatron.addons.function_wrapper import FUNCTION_WRAPPER

# Minimum package versions enforced by check_package_version().
# The cntrainkit requirement is only checked when `import cntrainkit` above
# succeeded (otherwise `cntrainkit` is None).
apex_minimum_required_version: str = '0.1'
fa_minimum_required_version: str = '2.4.3'  # flash_attn
te_minimum_required_version: str = '1.4.0'  # transformer_engine
cntrainkit_minimum_required_version: str = '0.1.0'


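# New feature: adapt autoresume. Registered via FUNCTION_WRAPPER to replace
# megatron.training.utils.check_adlr_autoresume_termination.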
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.training.utils.check_adlr_autoresume_termination')
def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, opt_param_scheduler):
    """Check for autoresume signal and exit if it is received.

    Registered through ``FUNCTION_WRAPPER`` under the name
    ``megatron.training.utils.check_adlr_autoresume_termination``.

    On every call this first hands the current training state and
    ``save_checkpoint`` to ``autoresume.update(...)``, then synchronizes with
    ``torch.distributed.barrier()``. If ``autoresume.termination_requested()``
    is true, it saves a checkpoint (when ``args.save`` is set), lets rank 0
    call ``autoresume.request_resume()``, and exits the process with status 0.

    Args:
        iteration: current training iteration; forwarded to
            ``autoresume.update`` and ``save_checkpoint``.
        model: model(s) being trained; forwarded likewise.
        optimizer: optimizer; forwarded likewise.
        opt_param_scheduler: optimizer-parameter scheduler; forwarded likewise.

    Raises:
        SystemExit: with code 0 when termination was requested.
    """
    # These names are not imported at module level in this file; without the
    # local imports below, the termination path raised NameError instead of
    # exiting cleanly. Imported lazily like save_checkpoint (presumably to
    # avoid an import cycle with megatron.training - TODO confirm).
    import sys

    from megatron.training.checkpointing import save_checkpoint
    from megatron.training.utils import print_rank_0

    args = get_args()
    autoresume = get_adlr_autoresume()
    autoresume.update(iteration, model, optimizer, opt_param_scheduler, save_checkpoint)
    # Add barrier to ensure consistency across ranks before checking the signal.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)

def _release_version_key(version):
    """Return a numerically comparable tuple for the release part of *version*.

    The local-version label (anything after '+') is dropped. An optional
    leading 'v' is accepted. Only the leading dotted run of integers is kept,
    so pre/post/dev suffixes are ignored: '2.4.3.post1+cu121' -> (2, 4, 3).
    Trailing zero components are stripped so '0.1' and '0.1.0' compare equal.
    A version with no leading number yields (), which sorts below any release.
    """
    import re

    public = version.split('+')[0].strip()
    match = re.match(r'[vV]?(\d+(?:\.\d+)*)', public)
    if match is None:
        return ()
    parts = [int(piece) for piece in match.group(1).split('.')]
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def _require_min_version(package, installed, required):
    """Raise AssertionError if *installed* is older than *required*.

    Raised explicitly (not via ``assert``) so the check survives ``python -O``
    while keeping the exception type the original assertions produced.
    """
    if _release_version_key(installed) < _release_version_key(required):
        raise AssertionError(
            f'{package} version({installed}) must be at least {required}')


def check_package_version():
    """Verify that required packages meet their minimum versions.

    Checks apex, flash_attn and transformer_engine. It also checks cntrainkit
    when that module was importable at load time (``cntrainkit`` is not None).
    Versions are compared numerically on their release components rather than
    as strings, so for example flash_attn '2.10.0' correctly satisfies '2.4.3'.

    Raises:
        importlib.metadata.PackageNotFoundError: if a checked package has no
            installed distribution metadata.
        AssertionError: if an installed version is older than required.
    """
    requirements = [
        ('apex', apex_minimum_required_version),
        ('flash_attn', fa_minimum_required_version),
        ('transformer_engine', te_minimum_required_version),
    ]
    # cntrainkit is optional: only check it when the module import succeeded.
    if cntrainkit is not None:
        requirements.append(('cntrainkit', cntrainkit_minimum_required_version))

    # Look up every distribution before comparing, so a missing package is
    # reported ahead of any version mismatch. Drop the '+local' suffix so the
    # error messages show the public version consistently for all packages.
    installed = {
        package: metadata.version(package).split('+')[0]
        for package, _ in requirements
    }

    for package, required in requirements:
        _require_min_version(package, installed[package], required)
