import numpy as np


def apply_decoder(
    parameters: dict,
    token_ids: np.ndarray,
    weight_sharing: bool = False,
) -> np.ndarray:
    """
    Compute next-token logits with a single attention-only decoder layer.

    The forward pass is:
    1. look up token embeddings,
    2. apply one causal self-attention block and add its output back onto the
       embeddings (residual connection),
    3. project the result onto the vocabulary with the unembedding matrix.

    Parameters
    ----------
    - `parameters`: A dictionary with the following keys:
        - `key_parameters`: A 2D numpy array of shape `(embedding_dim, key_dim)`
        - `output_parameters`: A 2D numpy array of shape `(value_dim, embedding_dim)`
        - `query_parameters`: A 2D numpy array of shape `(embedding_dim, key_dim)`
        - `token_embeddings`: A 2D numpy array of shape `(vocab_size, embedding_dim)`
        - `unembedding`: A 2D numpy array of shape `(embedding_dim, vocab_size)`.
            Only read when `weight_sharing` is `False`; when it is `True` this key
            is never accessed (it may be absent) and the transpose of
            `token_embeddings` is used instead.
        - `value_parameters`: A 2D numpy array of shape `(embedding_dim, value_dim)`

    - `token_ids`: A 2D integer numpy array of shape `(batch_dim, sequence_length)`
    - `weight_sharing`: A boolean indicating whether to use weight sharing between
        the token embeddings and the unembedding matrix. Defaults to `False`.

    Returns
    -------
    - A numpy array `logits` of shape `(batch_dim, sequence_length, vocab_size)`
        representing the logits for the next token prediction at each position in
        the sequence.
    """
    key_parameters = parameters['key_parameters']
    output_parameters = parameters['output_parameters']
    query_parameters = parameters['query_parameters']
    token_embeddings = parameters['token_embeddings']
    if weight_sharing:
        # Tied weights: reuse the embedding table as the output projection.
        unembedding = token_embeddings.T
    else:
        unembedding = parameters['unembedding']

    value_parameters = parameters['value_parameters']

    # (batch_dim, sequence_length, embedding_dim)
    embeddings = get_embedding(token_embeddings, token_ids)
    keys, queries, values = get_kqv(
        key_parameters,
        query_parameters,
        embeddings,
        value_parameters,
    )
    # get_attention_logits applies a causal mask by default, so each position
    # only attends to itself and earlier positions.
    attention_logits = get_attention_logits(keys, queries)
    attention_weights = get_attention_weights(attention_logits)
    aggregated_values = attention_weights @ values
    # Residual connection: the attention output is projected back to
    # embedding_dim and added to the input embeddings.
    features = embeddings + aggregated_values @ output_parameters

    logits = features @ unembedding

    return logits


def get_attention_logits(
    keys: np.ndarray,
    queries: np.ndarray,
    causal: bool = True,
) -> np.ndarray:
    """
    Compute scaled dot-product attention logits, optionally with a causal mask.

    Parameters
    ----------
    - `keys`: A numpy array of shape `(batch_dim, key_length, key_dim)`.
        Any number of leading batch dimensions is supported (matmul broadcasting).
    - `queries`: A numpy array of shape `(batch_dim, query_length, key_dim)`.
        In self-attention `query_length == key_length`.
    - `causal`: A boolean indicating whether to apply a causal mask, that is to set
        attention logits at positions `(i, j)` to `-inf` for all `j > i`.
        Defaults to `True`.

    Returns
    -------
    - A numpy array `logits` of shape `(batch_dim, query_length, key_length)`
        such that `logits[b, i, j]` is the dot product of `queries[b, i]` and
        `keys[b, j]`, divided by the square root of `key_dim`. When `causal` is
        `True`, every entry with `j > i` is `-inf`.
    """
    key_dim = keys.shape[-1]
    logits = (queries @ np.swapaxes(keys, -2, -1)) * key_dim ** -0.5
    if causal:
        # Build the mask from both lengths so it always has shape
        # (query_length, key_length). It then lines up with the trailing axes
        # of `logits` instead of broadcasting to a wrong shape when the two
        # lengths differ.
        query_positions = np.arange(queries.shape[-2])
        key_positions = np.arange(keys.shape[-2])
        causal_mask = key_positions[None, :] <= query_positions[:, None]
        logits = np.where(causal_mask, logits, -np.inf)

    return logits


def get_attention_weights(attention_logits: np.ndarray) -> np.ndarray:
    """
    Apply a numerically stable softmax over the last axis of the attention logits.

    Parameters
    ----------
    - `attention_logits`: A numpy array of shape
        `(..., sequence_length, sequence_length)`, e.g.
        `(batch_dim, sequence_length, sequence_length)`. Masked entries may be `-inf`.

    Returns
    -------
    - A numpy array `attention_weights` of the same shape, such that
        `attention_weights[..., i, :]` is the softmax of `attention_logits[..., i, :]`.
        That is, for each row:
        1. we subtract the row maximum from each element of that row,
           to avoid overflow in the next step;
        2. we exponentiate each shifted element; and
        3. we divide each exponentiated element by the sum of the exponentiated
           elements of that row.
        Entries equal to `-inf` receive weight `0`. A row consisting entirely of
        `-inf` produces NaN, because its maximum is itself `-inf`.
    """
    # Shifting by the row max makes the largest exponent exp(0) = 1, so no
    # overflow; the softmax is invariant to this shift.
    attention_logits = attention_logits - np.max(attention_logits, axis=-1, keepdims=True)
    weights = np.exp(attention_logits)
    weights /= np.sum(weights, axis=-1, keepdims=True)

    return weights


def get_embedding(
    token_embeddings: np.ndarray,
    token_ids: np.ndarray,
) -> np.ndarray:
    """
    Gather the embedding vector of every token id.

    Parameters
    ----------
    - `token_embeddings`: A 2D numpy array of shape `(vocab_size, embedding_dim)`
    - `token_ids`: An integer numpy array of shape `(batch_dim, sequence_length)`.
        A list or tuple of ids is also accepted and treated as an array of ids.

    Returns
    -------
    - A numpy array `features` of shape `(batch_dim, sequence_length, embedding_dim)`
        (in general `shape(token_ids) + (embedding_dim,)`) such that
        `features[b, s]` is `token_embeddings[token_ids[b, s]]`.
    """
    if isinstance(token_ids, tuple):
        # NumPy reads a plain tuple index as one index per axis
        # (E[(1, 2)] == E[1, 2], a scalar), not as a list of rows to gather.
        token_ids = list(token_ids)

    features = token_embeddings[token_ids]

    return features



def get_kqv(
    key_parameters: np.ndarray,
    query_parameters: np.ndarray,
    residual_input: np.ndarray,
    value_parameters: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Project the residual stream into keys, queries and values.

    Parameters
    ----------
    - `key_parameters`: A 2D numpy array of shape `(embedding_dim, key_dim)`
    - `query_parameters`: A 2D numpy array of shape `(embedding_dim, key_dim)`
    - `residual_input`: A numpy array of shape `(..., sequence_length, embedding_dim)`,
        e.g. `(batch_dim, sequence_length, embedding_dim)`
    - `value_parameters`: A 2D numpy array of shape `(embedding_dim, value_dim)`

    Returns
    -------
    A tuple of three numpy arrays (leading dimensions follow `residual_input`):
    - `keys`: The key array of shape `(..., sequence_length, key_dim)`.
        It is the matrix product of `residual_input` and `key_parameters`.
    - `queries`: The query array of shape `(..., sequence_length, key_dim)`.
        It is the matrix product of `residual_input` and `query_parameters`.
    - `values`: The value array of shape `(..., sequence_length, value_dim)`.
        It is the matrix product of `residual_input` and `value_parameters`.
    """
    # `@` contracts the last axis of residual_input with the first axis of each
    # weight matrix, so any leading (batch) dimensions pass through unchanged.
    keys = residual_input @ key_parameters
    queries = residual_input @ query_parameters
    values = residual_input @ value_parameters

    return keys, queries, values