# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Cambricon CORPORATION. All rights reserved.
from abc import ABC
from abc import abstractmethod

from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer

from megatron.training.tokenizer.bert_tokenization import FullTokenizer as FullBertTokenizer
from megatron.training.tokenizer.gpt2_tokenization import GPT2Tokenizer
from transformers import PreTrainedTokenizerFast, AutoTokenizer
from megatron.addons.function_wrapper import FUNCTION_WRAPPER

from transformers import PreTrainedTokenizerFast, AutoTokenizer

from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding, _GPT2BPETokenizer, _Llama2Tokenizer, _Llama3Tokenizer, _BertWordPieceTokenizer, _SentencePieceTokenizer, _GPTSentencePieceTokenizer, _NullTokenizer
class _HFTokenizer(MegatronTokenizer):
    """Megatron tokenizer backed by a Hugging Face ``AutoTokenizer``.

    The wrapped tokenizer is exposed as ``self.tokenizer``. ``self.encoder``
    (token -> id) and ``self.decoder`` (id -> token) are snapshots of the
    vocabulary taken once at construction time.
    """

    def __init__(self, tokenizer_name_or_path, vocab_extra_ids):
        """Load the pretrained tokenizer.

        Args:
            tokenizer_name_or_path: Forwarded to
                ``AutoTokenizer.from_pretrained``; also passed to the
                base-class constructor.
            vocab_extra_ids: Number of ``<extra_id_N>`` tokens to request as
                additional special tokens. Nothing is requested unless this
                is greater than zero.
        """
        name = tokenizer_name_or_path
        super().__init__(name)
        hf_tokenizer_kwargs = {}
        if vocab_extra_ids > 0:
            hf_tokenizer_kwargs["additional_special_tokens"] = [
                f"<extra_id_{extra_id}>" for extra_id in range(vocab_extra_ids)
            ]

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs)
        # Padding reuses the EOS token. Skip the assignment when the
        # checkpoint has no EOS token so that an existing pad token is not
        # replaced by None.
        if self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.encoder = self.tokenizer.get_vocab()
        self.decoder = {v: k for k, v in self.encoder.items()}

    @property
    def vocab_size(self):
        """Return ``vocab_size`` as reported by the wrapped tokenizer."""
        # NOTE(review): Hugging Face documents `vocab_size` as the size of the
        # base vocabulary, without added tokens. Confirm this is the intended
        # value when `vocab_extra_ids > 0`, since it feeds padded_vocab_size.
        return self.tokenizer.vocab_size

    @property
    def vocab(self):
        """Return a token -> id dict, including additional special tokens."""
        # `get_vocab()` is the accessor the constructor already relies on;
        # the `.vocab` attribute is not available on every tokenizer class.
        return {
            **{
                special_token: self.tokenizer.convert_tokens_to_ids(special_token)
                for special_token in self.tokenizer.additional_special_tokens
            },
            **self.tokenizer.get_vocab(),
        }

    @property
    def inv_vocab(self):
        """Return an id -> token dict built by inverting ``vocab``."""
        return {v: k for k, v in self.vocab.items()}

    def tokenize(self, text):
        """Return the result of the wrapped tokenizer's ``encode(text)``."""
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids):
        """Return the result of the wrapped tokenizer's ``decode(token_ids)``."""
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        """End-of-document id; identical to ``eos``."""
        return self.eos

    @property
    def cls(self):
        """CLS token id. Raises ``AttributeError`` if the tokenizer has none."""
        candidate = self.tokenizer.cls_token_id
        return self._check_token_candidate(candidate)

    @property
    def sep(self):
        """SEP token id. Raises ``AttributeError`` if the tokenizer has none."""
        candidate = self.tokenizer.sep_token_id
        return self._check_token_candidate(candidate)

    @property
    def pad(self):
        """Pad token id, falling back to the EOS id when no pad id is set.

        Raises ``AttributeError`` if neither id is available.
        """
        candidate = self.tokenizer.pad_token_id
        if candidate is None:
            candidate = self.tokenizer.eos_token_id
        return self._check_token_candidate(candidate)

    @property
    def mask(self):
        """MASK token id. Raises ``AttributeError`` if the tokenizer has none."""
        candidate = self.tokenizer.mask_token_id
        return self._check_token_candidate(candidate)

    @property
    def bos(self):
        """BOS token id. Raises ``AttributeError`` if the tokenizer has none."""
        candidate = self.tokenizer.bos_token_id
        return self._check_token_candidate(candidate)

    @property
    def eos(self):
        """EOS token id. Raises ``AttributeError`` if the tokenizer has none."""
        candidate = self.tokenizer.eos_token_id
        return self._check_token_candidate(candidate)

    @property
    def additional_special_tokens_ids(self):
        """Ids of the additional special tokens, as reported by the wrapped tokenizer."""
        return self.tokenizer.additional_special_tokens_ids

    @staticmethod
    def _check_token_candidate(candidate):
        """Return ``candidate`` unchanged, or raise ``AttributeError`` if it is None."""
        if candidate is None:
            raise AttributeError("Token doesn't exist")
        return candidate

class _PreTrainedTokenizerFast(MegatronTokenizer):
    """Megatron tokenizer backed by a Hugging Face ``PreTrainedTokenizerFast``."""

    def __init__(self, tokenizer_path):
        """Load the tokenizer from ``tokenizer_path``.

        Args:
            tokenizer_path: Forwarded to
                ``PreTrainedTokenizerFast.from_pretrained``; also passed to
                the base-class constructor.
        """
        super().__init__(tokenizer_path)
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained(
            tokenizer_path, legacy=False, from_slow=False)

    def tokenize(self, text):
        """Split ``text`` into tokens and return the corresponding ids."""
        text_tokens = self.tokenizer.tokenize(text)
        return self.tokenizer.convert_tokens_to_ids(text_tokens)

    def detokenize(self, token_ids):
        """Return the result of the wrapped tokenizer's ``decode(token_ids)``."""
        return self.tokenizer.decode(token_ids)

    @property
    def vocab(self):
        """Return the wrapped tokenizer's token -> id vocabulary."""
        return self.tokenizer.vocab

    @property
    def vocab_size(self):
        """Return the number of entries in ``vocab``."""
        return len(self.tokenizer.vocab)

    @property
    def inv_vocab(self):
        """Return an id -> token dict built by inverting ``vocab``."""
        # The previous implementation returned the bound method
        # `tokenizer._decode`, which is not a mapping and cannot be indexed
        # by token id.
        return {token_id: token for token, token_id in self.tokenizer.vocab.items()}

    @property
    def eod(self):
        """End-of-document id: the wrapped tokenizer's ``eos_token_id``."""
        return self.tokenizer.eos_token_id

# New feature: Llama3 tokenizer support.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.training.tokenizer.tokenizer.build_tokenizer')
def build_tokenizer(args):
    """Initialize the tokenizer selected by ``args.tokenizer_type``.

    Args:
        args: Argument namespace. Reads ``rank``, ``tokenizer_type`` and,
            depending on the type, ``vocab_file``, ``merge_file``,
            ``tokenizer_model``, ``tokenizer_name_or_path``,
            ``vocab_extra_ids`` and ``vocab_size``. When
            ``args.padded_vocab_size`` is missing or None it is set from the
            tokenizer's vocabulary size via ``_vocab_size_with_padding``.

    Returns:
        The constructed tokenizer instance.

    Raises:
        AssertionError: If an argument required by the selected tokenizer
            type is None.
        NotImplementedError: If ``args.tokenizer_type`` is not recognised.
    """
    if args.rank == 0:
        print('> building {} tokenizer ...'.format(args.tokenizer_type),
              flush=True)

    # Select and instantiate the tokenizer.
    if args.tokenizer_type == 'BertWordPieceLowerCase':
        assert args.vocab_file is not None
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=True,
                                            vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'BertWordPieceCase':
        assert args.vocab_file is not None
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=False,
                                            vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.vocab_file is not None
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
    elif args.tokenizer_type == 'SentencePieceTokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _SentencePieceTokenizer(args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'GPTSentencePieceTokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model)
    elif args.tokenizer_type == 'Llama2Tokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _Llama2Tokenizer(args.tokenizer_model)
    elif args.tokenizer_type == 'Llama3Tokenizer':
        # Same precondition as the other tokenizer_model-based branches.
        assert args.tokenizer_model is not None
        tokenizer = _Llama3Tokenizer(args.tokenizer_model)
    elif args.tokenizer_type == 'PreTrainedTokenizerFast':
        assert args.tokenizer_name_or_path is not None
        tokenizer = _PreTrainedTokenizerFast(args.tokenizer_name_or_path)
    elif args.tokenizer_type == 'HFTokenizer':
        assert args.tokenizer_name_or_path is not None
        import logging
        # The module only does `from transformers import ...`, so the bare
        # name `transformers` is not bound; import the logging helper here.
        from transformers.utils import logging as hf_logging
        # Keep library logging verbose on rank 0 only.
        if args.rank == 0:
            hf_logging.set_verbosity(logging.INFO)
        else:
            hf_logging.set_verbosity(logging.ERROR)
        if args.rank == 0:
            print(" vocab file is un-used. loading tokenizer from pre-trained model")
        tokenizer = _HFTokenizer(args.tokenizer_name_or_path, vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'NullTokenizer':
        assert args.vocab_size is not None
        tokenizer = _NullTokenizer(args.vocab_size)
    else:
        raise NotImplementedError('{} tokenizer is not '
                                  'implemented.'.format(args.tokenizer_type))

    # Add vocab size (if not already set from a checkpoint).
    if getattr(args, "padded_vocab_size", None) is None:
        args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
                                                          args)

    return tokenizer
