from megatron.addons.function_wrapper.function_wrapper import FunctionWrapper
from megatron.training.arguments import parse_args

import AA

# Module-level FunctionWrapper shared by the tests below: wrappers are
# registered on it via register_wrapper(...) and activated with
# func_wrapper_apply().
function_wrapper: FunctionWrapper = FunctionWrapper()

def extra_args_test(parser):
    """Add the ``--test-1`` and ``--test-2`` options to ``parser``.

    Intended for use as an ``extra_args_provider``. Both options go into an
    argument group titled ``test_func_wrapper``.

    Args:
        parser: object exposing ``add_argument_group`` (e.g. an
            ``argparse.ArgumentParser``).

    Returns:
        The same ``parser`` object that was passed in.
    """
    test_group = parser.add_argument_group(title='test_func_wrapper')
    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        ('--test-1', dict(type=str, help='test1.')),
        ('--test-2', dict(type=int, default=None, help='test2.')),
    )
    for flag, options in option_specs:
        test_group.add_argument(flag, **options)
    return parser

def test_func_wrapper():
    """Check that wrappers registered on ``function_wrapper`` take effect.

    Registers wrappers for ``megatron.training.arguments.parse_args`` and for
    several methods of classes in the ``AA`` module. Each method wrapper
    calls the original via ``get_orig_func`` and scales its result. The test
    then applies all wrappers with ``func_wrapper_apply()`` and asserts on
    the wrapped results.

    NOTE(review): nothing here undoes ``func_wrapper_apply()``, so the patches
    presumably remain active for the rest of the process — confirm against
    ``FunctionWrapper`` if test isolation matters.
    """
    # Local import so parse_args can be looked up on the module *after*
    # func_wrapper_apply() has run (see the assertion section below).
    import megatron.training.arguments as megatron_arguments

    @function_wrapper.register_wrapper(func_name='megatron.training.arguments.parse_args')
    def _extra_argument(extra_args_provider=None, ignore_unknown_args=False):
        # Add the test options on top of any provider the caller supplied,
        # and forward ignore_unknown_args instead of silently dropping both.
        args_func = function_wrapper.get_orig_func('megatron.training.arguments.parse_args')

        def _provider(parser):
            parser = extra_args_test(parser)
            if extra_args_provider is not None:
                parser = extra_args_provider(parser)
            return parser

        return args_func(extra_args_provider=_provider,
                         ignore_unknown_args=ignore_unknown_args)

    # Distinct names per wrapper: the originals all reused `_aa`, shadowing
    # each other. Registration is keyed by func_name, not the Python name.
    @function_wrapper.register_wrapper(func_name='AA.A.aa')
    def _a_aa(self, x, y):
        orig_func = function_wrapper.get_orig_func('AA.A.aa')
        return 10 * orig_func(self, x, y)

    @function_wrapper.register_wrapper(func_name='AA.B.bb')
    def _b_bb(self, x, y):
        orig_func = function_wrapper.get_orig_func('AA.B.bb')
        return 6 * orig_func(self, x, y)

    @function_wrapper.register_wrapper(func_name='AA.B.aa')
    def _b_aa(self, x, y):
        orig_func = function_wrapper.get_orig_func('AA.B.aa')
        return 100 * orig_func(self, x, y)

    @function_wrapper.register_wrapper(func_name='AA.C.cc')
    def _c_cc(self, x, y):
        orig_func = function_wrapper.get_orig_func('AA.C.cc')
        return 18 * orig_func(self, x, y)

    @function_wrapper.register_wrapper(func_name='AA.C.aa')
    def _c_aa(self, x, y):
        orig_func = function_wrapper.get_orig_func('AA.C.aa')
        return 18 * orig_func(self, x, y)

    function_wrapper.func_wrapper_apply()

    a = AA.A()
    assert a.aa(2, 3) == -10

    b = AA.B()
    assert b.bb(2, 3) == 30

    b1 = AA.B()
    assert b1.aa(1, 2) == -100

    c = AA.C()
    assert c.cc(4, 2) == 36
    assert c.aa(4, 2) == 144

    # Look parse_args up on its module at call time. The `parse_args` name
    # imported at file top was bound before func_wrapper_apply() ran, so it
    # only reaches the wrapper if FunctionWrapper patches more than the
    # module attribute (can't tell from here).
    # ignore_unknown_args=True stops argparse from exiting on the test
    # runner's own argv (e.g. pytest's file/option arguments).
    args = megatron_arguments.parse_args(ignore_unknown_args=True)
    assert hasattr(args, 'test_1')
    assert hasattr(args, 'test_2')
