import sys
from typing import Optional, List, Union, Tuple

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import LlamaForCausalLM, LlamaPreTrainedModel, LlamaConfig, AutoModel
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from transformers.models.idefics.modeling_idefics import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaModel
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
import torch.distributed as dist


def _make_causal_mask(
        input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    """
    prompt type1: "{}", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>
    token ids: [376, ..., 9162, 19138, 675, 278, 2038, 13382, 2629, 9475, 3838, 29901, 29871, 
                32008, 32011, 32004, 32013, 32007, 32005, 32002, 32014]

    prompt type2: "{}", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>
    token ids: [376, ..., 9162, 8500, 278, 1494, 13382, 2629, 9475, 3838, 29901, 29871, 32000, 
                32009, 32012, 32001, 32010, 32003, 32006, 32015]
    """
    summarize_suffix_ids = [999, 888]
    predict_suffix_ids = [666, 777]
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    expanded_mask[:, :, - len(predict_suffix_ids):,
    -len(summarize_suffix_ids) - len(predict_suffix_ids): - len(predict_suffix_ids)] = 0

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length):
    """
    Combine the causal mask and the padding mask into one additive attention mask.

    Args:
        attention_mask: ``[bsz, src_len]`` padding mask, or ``None``.
        input_shape: ``(bsz, tgt_len)`` of the current input chunk.
        inputs_embeds: only its ``dtype`` and ``device`` are read here.
        past_key_values_length: number of cached key positions.

    Returns:
        ``[bsz, 1, tgt_len, src_len]`` additive mask (0 = attend, large negative =
        blocked), the causal mask alone when ``attention_mask`` is ``None``, or
        ``None`` when there is neither a padding mask nor more than one query token.
    """
    dtype = inputs_embeds.dtype
    device = inputs_embeds.device
    query_len = input_shape[-1]

    # A single query position has no later positions in its chunk to hide, so the
    # causal part is only built for multi-token inputs.
    causal_mask = None
    if query_len > 1:
        causal_mask = _make_causal_mask(
            input_shape,
            dtype,
            device=device,
            past_key_values_length=past_key_values_length,
        )

    if attention_mask is None:
        return causal_mask

    padding_mask = _expand_mask(attention_mask, dtype, tgt_len=query_len).to(device)
    if causal_mask is None:
        return padding_mask
    # Additive masks compose by summation: a position blocked by either stays blocked.
    return padding_mask + causal_mask

# Demo: build the combined mask for a toy batch, then print which (query, key)
# pairs may attend (True) and the mask shape. Falls back to CPU when CUDA is
# unavailable instead of crashing on a hard-coded 'cuda:0'.
_demo_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
input_ids = torch.tensor(
    [[1, 5, 6, 999, 888, 666, 777],
     [0, 0, 6, 999, 888, 666, 777]],
    device=_demo_device,
)
# Sample 1 is left-padded: its first two positions are masked out.
attention_mask = torch.ones(input_ids.shape, device=_demo_device)
attention_mask[1][:2] = 0
# _prepare_decoder_attention_mask only reads dtype/device from inputs_embeds;
# shape it like real embeddings (bsz, seq_len, hidden) so the demo isn't misleading.
inputs_embeds = torch.rand((*input_ids.shape, 10), device=_demo_device)
a = _prepare_decoder_attention_mask(attention_mask, input_ids.shape, inputs_embeds, 0)
# Additive mask: 0 where attention is allowed, large negative where blocked.
print(a > -1)
print(a.shape)