# Copyright (c) 2024, Cambricon CORPORATION. All rights reserved.
import torch

class DummyProfile:
    """No-op stand-in for ``torch.profiler.profile``.

    Accepts and ignores any constructor arguments, works as a context
    manager that yields itself, and exposes a ``step()`` that does nothing,
    so a training loop can use it interchangeably with the real profiler.
    Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, *_args, **_kwargs) -> None:
        # Intentionally ignore every argument: nothing is configured.
        return None

    def __enter__(self):
        return self

    def __exit__(self, *_exc_info):
        # Returning None (falsy) lets any in-flight exception propagate.
        return None

    def step(self):
        """Do nothing; mirrors ``torch.profiler.profile.step``."""
        return None


def initialize_profile(args, *, warmup_steps=1):
    """Build the profiler context used around the training loop.

    When ``args.torch_profile`` is truthy, return a ``torch.profiler.profile``
    that records CPU and MLU activity, with shapes and FLOPs. It records the
    step window ``[args.profile_step_start, args.profile_step_end)`` exactly
    once and writes the trace with
    ``torch.profiler.tensorboard_trace_handler(args.profile_path)``.
    Otherwise, return a ``DummyProfile`` exposing the same ``with`` /
    ``step()`` interface.

    Step indices count calls to ``step()`` on the returned object, starting
    at 0.

    Args:
        args: Namespace-like object. ``torch_profile`` is always read.
            ``profile_step_start``, ``profile_step_end`` (int or
            int-convertible) and ``profile_path`` are read only when
            profiling is enabled.
        warmup_steps: Number of warmup steps before the active window. It is
            clamped to ``profile_step_start`` so the active window still
            begins exactly at that step.

    Returns:
        A ``torch.profiler.profile`` instance or a ``DummyProfile``.

    Raises:
        ValueError: If ``profile_step_start`` is negative, or if
            ``profile_step_end`` is not greater than ``profile_step_start``.
    """
    if not args.torch_profile:
        return DummyProfile()

    start_step = int(args.profile_step_start)
    end_step = int(args.profile_step_end)
    if start_step < 0:
        raise ValueError(
            f"profile_step_start must be >= 0, got {start_step}")
    active_steps = end_step - start_step
    if active_steps <= 0:
        raise ValueError(
            f"profile_step_end ({end_step}) must be greater than "
            f"profile_step_start ({start_step})")

    # Schedule layout: steps [0, wait) idle, [wait, start) warmup,
    # [start, end) recorded. Warmup must fit before start_step; otherwise
    # the recorded window shifts later (start_step=0 with one warmup step
    # would begin recording at step 1).
    warmup = max(0, min(warmup_steps, start_step))
    wait_steps = start_step - warmup

    return torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.MLU,
        ],
        # repeat=1 records the requested window once. repeat=0 would restart
        # the wait/warmup/active cycle after end_step and keep emitting traces
        # until the profiler is closed.
        schedule=torch.profiler.schedule(
            wait=wait_steps,
            warmup=warmup,
            active=active_steps,
            repeat=1,
            skip_first=0,
        ),
        record_shapes=True,
        with_stack=False,
        with_flops=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(args.profile_path),
    )
