Skip to content

Commit 6ae1ddd

Browse files
committed
Support RoBERTa pretraining
1 parent af6ee2d commit 6ae1ddd

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

  • FlagEmbedding/baai_general_embedding/retromae_pretrain

FlagEmbedding/baai_general_embedding/retromae_pretrain/modeling.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ def forward(self,
5353
decoder_embedding_output = self.decoder_embeddings(input_ids=decoder_input_ids)
5454
hiddens = torch.cat([cls_hiddens, decoder_embedding_output[:, 1:]], dim=1)
5555

56-
decoder_position_ids = self.lm.bert.embeddings.position_ids[:, :decoder_input_ids.size(1)]
57-
decoder_position_embeddings = self.lm.bert.embeddings.position_embeddings(decoder_position_ids) # B L D
58-
query = decoder_position_embeddings + cls_hiddens
56+
# decoder_position_ids = self.lm.bert.embeddings.position_ids[:, :decoder_input_ids.size(1)]
57+
# decoder_position_embeddings = self.lm.bert.embeddings.position_embeddings(decoder_position_ids) # B L D
58+
# query = decoder_position_embeddings + cls_hiddens
5959

60-
# cls_hiddens = cls_hiddens.expand(hiddens.size(0), hiddens.size(1), hiddens.size(2))
61-
# query = self.decoder_embeddings(inputs_embeds=cls_hiddens)
60+
cls_hiddens = cls_hiddens.expand(hiddens.size(0), hiddens.size(1), hiddens.size(2))
61+
query = self.decoder_embeddings(inputs_embeds=cls_hiddens)
6262

6363
matrix_attention_mask = self.lm.get_extended_attention_mask(
6464
decoder_attention_mask,

0 commit comments

Comments
 (0)