#!/bin/bash
#
# Launch script for a LLaMA2-7B style pretraining run via torchrun.
#
# Usage: <this-script> [LLAMA_TOKENIZER_MODEL] [DATA_PATH]
#
# -e          abort on the first failing command.
# -o pipefail make a pipeline fail when ANY stage fails. The training command
#             is piped into `tee`; without this the pipeline would report
#             tee's (successful) status and a crashed run would exit 0.
set -e
set -o pipefail

# Absolute directory containing this script; the run log is written there.
# Quoted so a script path with spaces works, and `&&` so a failed `cd` aborts
# instead of silently yielding the caller's working directory. `cd` output is
# discarded because it may print the target when CDPATH is set.
CUR_DIR=$(cd -- "$(dirname -- "$0")" >/dev/null && pwd)

# Timestamp (month-day-hour-minute) used to make the log file name unique.
DATESTR=$(date +"%m-%d-%H-%M")

# Positional arguments. Each default applies only when the argument is not
# supplied at all (an explicitly empty argument is kept as-is).
#   $1 - path to the LLaMA tokenizer model
#   $2 - path to the training dataset
LLAMA_TOKENIZER_MODEL="${1-/path-to-llama-tokenizer}"
DATA_PATH="${2-/path-to-dataset}"

# Launcher (torchrun) settings: one node, eight worker processes.
GPUS_PER_NODE="8"
NNODES="1"
NODE_RANK="0"
MASTER_ADDR="localhost"
MASTER_PORT="29600"

# Model-parallel layout: tensor-parallel and pipeline-parallel sizes.
TP_SIZE="8"
PP_SIZE="1"

# Network shape.
NUM_LAYERS="32"
HIDDEN_SIZE="4096"
FFN_HIDDEN_SIZE="11008"
NUM_HEADS="32"
SEQ_LENGTH="2048"

# Batching and run length.
MICRO_BATCH_SIZE="4"
GLOBAL_BATCH_SIZE="8"
TRAIN_STEPS="100"

# Optimizer / learning-rate schedule.
LR="1e-6"
MIN_LR="1e-8"
LR_WARMUP_STEPS="1"
WEIGHT_DECAY="0.1"
GRAD_CLIP="1"

# Each option group is listed as an array (one flag per line) and then
# flattened into a single space-separated string. The launch command consumes
# these strings by splitting them on whitespace, so only the token sequence
# matters, not the exact spacing.

# Operator-fusion switches.
op_fusion_flags=(
  --no-persist-layer-norm
  --no-bias-gelu-fusion
)
OP_FUSION_OPTIONS="${op_fusion_flags[*]}"

# LLaMA architecture switches.
llama_flags=(
  --attention-dropout 0
  --hidden-dropout 0
  --use-rotary-position-embeddings
  --untie-embeddings-and-output-weights
  --swiglu
  --normalization RMSNorm
  --disable-bias-linear
)
LLAMA_OPTIONS="${llama_flags[*]}"

# Model shape and batching; the maximum position count is tied to the
# sequence length.
model_flags=(
  --num-layers "$NUM_LAYERS"
  --hidden-size "$HIDDEN_SIZE"
  --ffn-hidden-size "$FFN_HIDDEN_SIZE"
  --num-attention-heads "$NUM_HEADS"
  --seq-length "$SEQ_LENGTH"
  --max-position-embeddings "$SEQ_LENGTH"
  --micro-batch-size "$MICRO_BATCH_SIZE"
  --global-batch-size "$GLOBAL_BATCH_SIZE"
  --use-mcore-models
  --attention-softmax-in-fp32
  --make-vocab-size-divisible-by 100
)
MODEL_OPTIONS="${model_flags[*]}"

# Optimizer hyper-parameters and learning-rate schedule.
training_flags=(
  --lr "$LR"
  --lr-decay-style cosine
  --min-lr "$MIN_LR"
  --adam-beta1 0.9
  --adam-beta2 0.95
  --adam-eps 1e-8
  --weight-decay "$WEIGHT_DECAY"
  --clip-grad "$GRAD_CLIP"
  --lr-warmup-iters "$LR_WARMUP_STEPS"
)
TRAINING_OPTIONS="${training_flags[*]}"

# Arguments for the torchrun launcher itself.
distributed_flags=(
  --nproc_per_node "$GPUS_PER_NODE"
  --nnodes "$NNODES"
  --node_rank "$NODE_RANK"
  --master_addr "$MASTER_ADDR"
  --master_port "$MASTER_PORT"
)
DISTRIBUTED_ARGS="${distributed_flags[*]}"
  


# Assemble the torchrun invocation as an argument array and execute it
# directly instead of building a string and passing it to `eval`. With `eval`
# the user-supplied paths were re-parsed as shell code, so spaces or shell
# metacharacters in them could break the command or run arbitrary commands.
#
# The *_OPTIONS / DISTRIBUTED_ARGS variables are intentionally unquoted: each
# holds several whitespace-separated flags that must become separate arguments.
#
# $DATA_PATH is also left unquoted so that a value holding several
# whitespace-separated entries still expands to several arguments, as it did
# under `eval`.
# NOTE(review): if --data-path is only ever given a single path, quote it so
# that dataset paths containing spaces work.
#
# pretrain_gpt.py is given as a relative path, so this presumably has to be
# run from the directory that contains it — confirm.
# shellcheck disable=SC2206
cmd=(
  torchrun $DISTRIBUTED_ARGS pretrain_gpt.py
  $MODEL_OPTIONS
  $OP_FUSION_OPTIONS
  $TRAINING_OPTIONS
  $LLAMA_OPTIONS
  --tensor-model-parallel-size "$TP_SIZE"
  --pipeline-model-parallel-size "$PP_SIZE"
  --sequence-parallel
  --train-iters "$TRAIN_STEPS"
  --data-path $DATA_PATH
  --tokenizer-type Llama2Tokenizer
  --tokenizer-model "$LLAMA_TOKENIZER_MODEL"
  --split 949,50,1
  --distributed-backend cncl
  --log-interval 1
  --save-interval 10000
  --eval-interval 1000
  --no-load-optim
  --no-load-rng
  --position-embedding-type rope
  --no-position-embedding
  --use-flash-attn
  --eval-iters 1
  --bf16
)

# Show the full command on one line before running it.
printf '%s\n' "${cmd[*]}"

# Run the job, sending stdout and stderr both to the terminal and to a
# timestamped log file next to this script.
"${cmd[@]}" 2>&1 | tee "$CUR_DIR/log-llama2-7B-$DATESTR.txt"
