I think pure causal is conclusively better.

> but it actually adds to the calculations since the mask calculation is an extra step.

Masking saves prefill compute with flash attention (or, more generically, any attention implementation that computes mask(q@k.T) block-wise instead of as separate matmul and mask steps), because blocks the causal mask rules out entirely can simply be skipped. So in FLOPs terms, dropping the mask is definitely worse.
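
To make that concrete, here is a toy sketch of the block-skipping idea (not FlashAttention itself, which uses an online softmax; the block size, shapes, and per-block softmax are my own illustrative assumptions):

```python
# Toy block-wise causal attention: key blocks lying entirely in the future of
# the current query block are never computed, which is where the prefill FLOP
# savings come from. Plain per-block softmax for clarity, not FlashAttention's
# online softmax; block size and shapes are arbitrary illustrative choices.
import torch

def blockwise_causal_attention(q, k, v, block=128):
    n, d = q.shape
    out = torch.zeros_like(q)
    visited_pairs = 0
    for qs in range(0, n, block):
        qe = min(qs + block, n)
        scores, vals = [], []
        for ks in range(0, n, block):
            ke = min(ks + block, n)
            if ks >= qe:              # whole key block is in the future -> skip it
                continue
            visited_pairs += 1
            s = q[qs:qe] @ k[ks:ke].T / d ** 0.5
            if ke > qs:               # only the diagonal block needs element-wise masking
                qi = torch.arange(qs, qe).unsqueeze(1)
                ki = torch.arange(ks, ke).unsqueeze(0)
                s = s.masked_fill(ki > qi, float("-inf"))
            scores.append(s)
            vals.append(v[ks:ke])
        p = torch.softmax(torch.cat(scores, dim=-1), dim=-1)
        out[qs:qe] = p @ torch.cat(vals, dim=0)
    return out, visited_pairs

q = k = v = torch.randn(512, 64)
_, pairs = blockwise_causal_attention(q, k, v)
print(pairs, "of", (512 // 128) ** 2, "query/key block pairs computed")  # 10 of 16
```

With full (unmasked) attention, all 16 block pairs have to be computed, so going noncausal roughly doubles the attention-score FLOPs at long context.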

> There seems to be some loss of potential information and accuracy in the model

You can take the performance of encoder-based MLLMs as a proxy for the potential gains from noncausal attention. Because ViTs (and other modality encoders) are usually bidirectional (for good reason), using full attention over the cross-modal subsequence is common.

For example, https://arxiv.org/html/2409.03206#:~:text=Table%203%3A,Different%20Method%20Settings finds that a block-causal mask gives better video-understanding results than a full mask.

So it is possible that causal attention is actually beneficial for text overall.
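
For concreteness, here is a minimal sketch of the mask variants being compared, assuming a hypothetical [text prefix | vision tokens | text suffix] layout (segment lengths and the placement of the vision block are my own illustrative choices, not the linked paper's exact setup):

```python
# Boolean attention masks (True = query may attend to key) for a toy sequence
# laid out as [text prefix | vision tokens | text suffix]. The layout and
# segment lengths are illustrative assumptions, not the paper's exact setup.
import torch

def build_mask(n_prefix, n_vision, n_suffix, kind="block_causal"):
    """Return an (n, n) bool mask where mask[i, j] = True means query i may attend key j."""
    n = n_prefix + n_vision + n_suffix
    idx = torch.arange(n)
    causal = idx.unsqueeze(1) >= idx.unsqueeze(0)   # key index <= query index
    if kind == "causal":
        return causal
    if kind == "full":
        return torch.ones(n, n, dtype=torch.bool)   # every token sees every token
    # block-causal: causal overall, except the vision tokens attend to each
    # other bidirectionally (the block is fully connected internally)
    vision = (idx >= n_prefix) & (idx < n_prefix + n_vision)
    block = vision.unsqueeze(1) & vision.unsqueeze(0)
    return causal | block

print(build_mask(2, 3, 2, kind="block_causal").int())
```

The comparison cited above is between the block_causal and full variants here; pure causal would additionally break bidirectionality inside the vision block.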

> I suspect there are clever ways of training a decode only model and then fine tuning the prefill to work without the masking on a slightly different set of weights.

If I had to hedge: perhaps it will later be discovered that models trained on a pure next-token-prediction objective fail to benefit from such fine-tuning, while models pretrained with a mixture-of-denoisers objective (which already mixes in prefix-LM-style denoising) are more amenable to your suggested approach.
