#! /bin/bash
# Launch single-node GPT pretraining (pretrain_gpt.py) through torchrun.
#
# Usage: <this script> [TRAIN_ITERS] [MODE]
#   TRAIN_ITERS  number of training iterations (default: 50)
#   MODE         dataset mode forwarded to prepare_once.sh,
#                ONLINE or OFFLINE (default: ONLINE)
#
# Combined stdout/stderr is also written to
# logs/log-gpt-18B-<MM-DD-HH-MM>.txt next to this script.
set -e

# Absolute directory containing this script. '&&' makes a failed cd abort
# instead of silently reporting the caller's working directory, and the
# empty CDPATH stops cd from echoing a path into the captured output.
CUR_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd)
# Minute-resolution timestamp (MM-DD-HH-MM) used in the log file name.
DATESTR=$(date +"%m-%d-%H-%M")

#######################################
# Look up "key=value" style entries in a file.
# Arguments:
#   $1 - grep (basic regex) pattern selecting the line(s); it is not
#        anchored, so it also matches as a substring of other lines
#   $2 - file to search
# Outputs:
#   For every matching line, everything after the first '=' (values that
#   themselves contain '=' are kept intact).
# Returns:
#   The exit status of cut (last pipeline stage), so a key that is not
#   found still yields status 0 unless pipefail is enabled.
#######################################
get_prop() {
    # '--' lets patterns that start with '-' be treated as patterns;
    # '-f2-' keeps the whole value rather than only up to the next '='.
    grep -- "${1}" "${2}" | cut -d'=' -f2-
}


# Dataset setup. The second positional argument selects the dataset mode
# (ONLINE or OFFLINE, default ONLINE) and is forwarded to prepare_once.sh.
# The file is sourced, not executed, so variables it defines stay visible
# here. NOTE(review): it presumably provides
# LLAMA2_OPENWEBTEXT_DATASET_PATH, GPT2_VOCAB_PATH and GPT2_MERGES_PATH,
# which this script uses but never assigns; confirm against prepare_once.sh.
MODE=${2-"ONLINE"} # [ONLINE, OFFLINE]
source "$CUR_DIR/prepare_once.sh" "$MODE"

# --- torchrun rendezvous: a single node with eight devices ---
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=8
MASTER_ADDR=localhost
# Random rendezvous port in [10000, 65535], so two launches on the same
# host are unlikely to pick the same port.
MASTER_PORT=$(shuf -n 1 -i 10000-65535)

# --- parallelism ---
# Tensor-parallel degree; equal to GPUS_PER_NODE, so all eight devices of
# the single node form one tensor-parallel group.
TP_SIZE=8

# --- model shape ---
LAYERS=40
HIDDEN_SIZE=6144
NUM_ATTN_HEADS=48
SEQ_LEN=2048
MAX_POSITION_EMBEDDINGS=2048

# --- batch sizes (global equals micro in this configuration) ---
MICRO_BATCH_SIZE=8
GLOBAL_BATCH_SIZE=8

# --- training length: first positional argument, 50 when omitted ---
TRAIN_ITERS=${1-50}

# Each option group is stored as one space-separated string. The groups
# are concatenated into $cmd, which is later run through eval, so values
# expanded here (e.g. the dataset path) are re-split on whitespace at
# launch time.

# _join_opts VAR WORD... : store WORD... in VAR, each followed by a space.
_join_opts() {
    printf -v "$1" '%s ' "${@:2}"
}

# torchrun launcher settings.
_join_opts DDP_OPTIONS \
    --nproc_per_node "$GPUS_PER_NODE" \
    --nnodes "$NNODES" \
    --node_rank "$NODE_RANK" \
    --master_addr "$MASTER_ADDR" \
    --master_port "$MASTER_PORT"

# Communication backend and tensor parallelism.
_join_opts DISTRIBUTED_OPTIONS \
    --distributed-backend cncl \
    --sequence-parallel \
    --tensor-model-parallel-size "$TP_SIZE"

# Activation recomputation.
_join_opts RECOMPUTE_OPTIONS --recompute-granularity selective

# Model architecture and batch sizes.
_join_opts MODEL_OPTIONS \
    --num-layers "$LAYERS" \
    --hidden-size "$HIDDEN_SIZE" \
    --num-attention-heads "$NUM_ATTN_HEADS" \
    --micro-batch-size "$MICRO_BATCH_SIZE" \
    --global-batch-size "$GLOBAL_BATCH_SIZE" \
    --use-mcore-models \
    --attention-softmax-in-fp32 \
    --seq-length "$SEQ_LEN" \
    --max-position-embeddings "$MAX_POSITION_EMBEDDINGS"

# Precision, LR schedule, data files, optimizer and logging cadence.
_join_opts TRAIN_OPTIONS \
    --bf16 \
    --train-iters "$TRAIN_ITERS" \
    --lr-decay-iters 320000 \
    --data-path "$LLAMA2_OPENWEBTEXT_DATASET_PATH" \
    --vocab-file "$GPT2_VOCAB_PATH" \
    --merge-file "$GPT2_MERGES_PATH" \
    --split 949,50,1 \
    --lr 0.0006 \
    --lr-decay-style cosine \
    --min-lr 6.0e-5 \
    --weight-decay 0.1 \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --clip-grad 1.0 \
    --lr-warmup-fraction .01 \
    --log-interval 1 \
    --save-interval 10000 \
    --eval-interval 1000 \
    --eval-iters 1

# Full command line: torchrun <launcher options> pretrain_gpt.py <options>.
_join_opts cmd \
    torchrun "$DDP_OPTIONS" \
    pretrain_gpt.py \
    "$DISTRIBUTED_OPTIONS" \
    "$MODEL_OPTIONS" \
    "$TRAIN_OPTIONS" \
    "$RECOMPUTE_OPTIONS"

# The helper is only needed while assembling the strings above.
unset -f _join_opts


# Show the assembled command, then run it. Stdout and stderr are mirrored
# to a per-run log file under $CUR_DIR/logs.
# Unquoted on purpose: runs of whitespace inside $cmd collapse to single
# spaces, so the command prints on one line.
# shellcheck disable=SC2086
echo $cmd
# tee cannot create the log file if the logs directory does not exist.
mkdir -p "$CUR_DIR/logs"
# Without pipefail the pipeline's status is tee's, so a failed training run
# would still let the script exit 0. With pipefail, set -e propagates it.
set -o pipefail
eval "$cmd" 2>&1 | tee "$CUR_DIR/logs/log-gpt-18B-$DATESTR.txt"
