File tree Expand file tree Collapse file tree
FlagEmbedding/baai_general_embedding/retromae_pretrain — Expand file tree / Collapse file tree
Original file line number | Diff line number | Diff line change
@@ -53,12 +53,12 @@ def forward(self,
5353 decoder_embedding_output = self .decoder_embeddings (input_ids = decoder_input_ids )
5454 hiddens = torch .cat ([cls_hiddens , decoder_embedding_output [:, 1 :]], dim = 1 )
5555
56- decoder_position_ids = self .lm .bert .embeddings .position_ids [:, :decoder_input_ids .size (1 )]
57- decoder_position_embeddings = self .lm .bert .embeddings .position_embeddings (decoder_position_ids ) # B L D
58- query = decoder_position_embeddings + cls_hiddens
56+ # decoder_position_ids = self.lm.bert.embeddings.position_ids[:, :decoder_input_ids.size(1)]
57+ # decoder_position_embeddings = self.lm.bert.embeddings.position_embeddings(decoder_position_ids) # B L D
58+ # query = decoder_position_embeddings + cls_hiddens
5959
60- # cls_hiddens = cls_hiddens.expand(hiddens.size(0), hiddens.size(1), hiddens.size(2))
61- # query = self.decoder_embeddings(inputs_embeds=cls_hiddens)
60+ cls_hiddens = cls_hiddens .expand (hiddens .size (0 ), hiddens .size (1 ), hiddens .size (2 ))
61+ query = self .decoder_embeddings (inputs_embeds = cls_hiddens )
6262
6363 matrix_attention_mask = self .lm .get_extended_attention_mask (
6464 decoder_attention_mask ,
You can’t perform that action at this time.
0 commit comments