# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Cambricon CORPORATION. All rights reserved.
import os
from functools import wraps

import torch
try:
    # Cambricon MLU support is optional: when torch_mlu is not installed (or
    # fails to set up) the package still imports on stock PyTorch.
    import torch_mlu
    import torch._dynamo
    # Stop TorchDynamo from raising on compilation errors (per the TorchDynamo
    # config docs it then falls back to eager execution) -- confirm this is
    # still wanted on non-MLU hosts too.
    torch._dynamo.config.suppress_errors = True
    # Imported for its side effects; presumably it patches CUDA-facing torch
    # APIs onto the MLU backend -- TODO confirm against torch_mlu docs.
    from torch_mlu.utils.model_transfer import transfer
except Exception:
    # Best-effort setup: ignore any ordinary failure (missing module, broken
    # shared library, missing attribute). Unlike the previous bare ``except:``,
    # this no longer swallows KeyboardInterrupt / SystemExit.
    pass

from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import get_signal_handler
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_wandb_writer
from .global_vars import get_one_logger
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize  import initialize_megatron
from .training import pretrain, get_model, get_train_valid_test_num_samples

from megatron import addons

from .utils import (print_rank_0,
                    is_last_rank,
                    print_rank_last)

# When the public collective APIs are available, point the legacy
# underscore-prefixed names at them, so callers that still use
# ``_all_gather_base`` / ``_reduce_scatter_base`` get the public versions.
_dist = torch.distributed
if all(hasattr(_dist, api) for api in ("all_gather_into_tensor",
                                       "reduce_scatter_tensor")):
    _dist._all_gather_base = _dist.all_gather_into_tensor
    _dist._reduce_scatter_base = _dist.reduce_scatter_tensor
del _dist  # keep the temporary alias out of the package namespace

def wrapper_type(fn):
    """Wrap ``fn`` so MLU floating-point tensor type names read as CUDA ones.

    The wrapper calls ``fn`` with the same arguments. If the result is one of
    ``'torch.mlu.FloatTensor'``, ``'torch.mlu.BFloat16Tensor'`` or
    ``'torch.mlu.HalfTensor'``, it is replaced by the matching
    ``'torch.cuda.*'`` name. Any other string, and any non-string result, is
    returned unchanged. Presumably this lets code that compares
    ``tensor.type()`` against CUDA type strings run on MLU devices -- confirm
    against callers.

    Args:
        fn: callable to wrap (in this module, the original ``torch.Tensor.type``).

    Returns:
        A wrapper carrying ``fn``'s metadata via ``functools.wraps``.
    """
    mlu_to_cuda = {
        'torch.mlu.FloatTensor': 'torch.cuda.FloatTensor',
        'torch.mlu.BFloat16Tensor': 'torch.cuda.BFloat16Tensor',
        'torch.mlu.HalfTensor': 'torch.cuda.HalfTensor',
    }

    @wraps(fn)
    def patched(*args, **kwargs):
        result = fn(*args, **kwargs)
        # Non-string results (e.g. a converted tensor) pass straight through.
        if not isinstance(result, str):
            return result
        return mlu_to_cuda.get(result, result)

    return patched

# Limit CUDA_DEVICE_MAX_CONNECTIONS to a single hardware queue. This
# unconditionally overrides any value set by the user. Presumably this is the
# upstream Megatron requirement for ordered comm/compute kernel issue --
# confirm.
os.environ.update({'CUDA_DEVICE_MAX_CONNECTIONS': '1'})

# Monkeypatch Tensor.type globally so MLU float/bf16/half type strings are
# reported with their torch.cuda.* spelling.
setattr(torch.Tensor, 'type', wrapper_type(torch.Tensor.type))
