Skip to content

Commit 6fd68fe

Browse files
committed
chore: change ratio for hybrid dbfs
1 parent 630223c commit 6fd68fe

1 file changed

Lines changed: 79 additions & 20 deletions

File tree

examples/inference/embedder/encoder_only/m3_single_device_ensemble.py

Lines changed: 79 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,26 @@
55

66

77
def pad_colbert_vecs(colbert_vecs_list, device):
8+
"""
9+
Since ColBERT embeddings are computed on a token-level basis, each document (or query)
10+
may produce a different number of token embeddings. This function aligns all embeddings
11+
to the same length by padding shorter sequences with zeros, ensuring that every input
12+
ends up with a uniform shape.
13+
14+
Steps:
15+
1. Determine the maximum sequence length (i.e., the largest number of tokens in any
16+
query or passage within the batch).
17+
2. For each set of token embeddings, pad it with zeros until it matches the max
18+
sequence length. Zero vectors act as placeholders for "no token": their dot product
19+
with any embedding is 0, so padding changes a score only if a query token's real similarities are all negative.
20+
3. Convert all padded embeddings into a single, consistent tensor and move it to the
21+
specified device (e.g., GPU) for efficient batch computation.
22+
23+
By performing this padding operation, subsequent tensor operations (like the einsum
24+
computations for ColBERT scoring) become simpler and more efficient, as all sequences
25+
share a common shape.
26+
"""
27+
828
lengths = [vec.shape[0] for vec in colbert_vecs_list]
929
max_len = max(lengths)
1030
dim = colbert_vecs_list[0].shape[1]
@@ -18,18 +38,57 @@ def pad_colbert_vecs(colbert_vecs_list, device):
1838

1939

2040
def compute_colbert_scores(query_colbert_vecs, passage_colbert_vecs):
21-
# query_colbert_vecs: (Q, Tq, D)
22-
# passage_colbert_vecs: (P, Tp, D)
23-
# In the einsum expression: q = queries, p = passages, r = query-token dim, c = passage-token dim, d = embedding dim
41+
"""
42+
Compute ColBERT scores:
43+
44+
ColBERT (Contextualized Late Interaction over BERT) evaluates the similarity
45+
between a query and a passage at the token level. Instead of producing a single
46+
dense vector for each query or passage, ColBERT maintains embeddings for every
47+
token. This allows for finer-grained matching, capturing more subtle similarities.
48+
49+
Definitions of variables:
50+
- q: Number of queries (Q)
51+
- p: Number of passages (P)
52+
- r: Number of tokens in each query (Tq)
53+
- c: Number of tokens in each passage (Tp)
54+
- d: Embedding dimension (D)
55+
56+
The scores are computed with `torch.einsum("qrd,pcd->qprc", query_colbert_vecs, passage_colbert_vecs)`:
57+
- einsum (Einstein summation) is a powerful notation and function for
58+
expressing and computing multi-dimensional tensor contractions. It allows you
59+
to specify how dimensions in input tensors correspond to each other and how
60+
they should be combined (multiplied and summed) to produce the output.
61+
62+
In this particular case:
63+
- "qrd" corresponds to (Q, Tq, D) for query token embeddings.
64+
- "pcd" corresponds to (P, Tp, D) for passage token embeddings.
65+
- "qrd,pcd->qprc" means:
66+
1. For each query q and passage p, compute the dot product between every query token
67+
embedding (r) and every passage token embedding (c) across the embedding dimension d.
68+
2. This results in a (Q, P, Tq, Tp) tensor (qprc), where each element is the similarity
69+
score between a single query token and a single passage token.
70+
71+
After computing this full matrix of token-to-token scores:
72+
- We take the maximum over the passage token dimension (c) for each query token (r).
73+
This step identifies, for each query token, which passage token is the "best match."
74+
- Then we sum over all query tokens (r) to aggregate their best matches into a single
75+
score per query-passage pair.
76+
77+
In summary:
78+
1. einsum to get all pairwise token similarities.
79+
2. max over passage tokens to find the best matching passage token for each query token.
80+
3. sum over query tokens to combine all the best matches into a final ColBERT score
81+
for each query-passage pair.
82+
"""
83+
2484
dot_products = torch.einsum("qrd,pcd->qprc", query_colbert_vecs, passage_colbert_vecs) # Q,P,Tq,Tp
25-
max_per_query_token, _ = dot_products.max(dim=3) # max over c (Tp)
26-
colbert_scores = max_per_query_token.sum(dim=2) # sum over r (Tq)
85+
max_per_query_token, _ = dot_products.max(dim=3) # max over the passage-token axis (Tp)
86+
colbert_scores = max_per_query_token.sum(dim=2) # sum over the query-token axis (Tq)
2787
return colbert_scores
2888

2989

30-
def hybrid_dbfs_ensemble(dense_scores, sparse_scores, colbert_scores, weights=(0.33, 0.33, 0.34)):
90+
def hybrid_dbfs_ensemble_simple_linear_combination(dense_scores, sparse_scores, colbert_scores, weights=(0.45, 0.45, 0.1)):
3191
w_dense, w_sparse, w_colbert = weights
32-
# The operation below works as long as every input is a torch.Tensor
3392
return w_dense * dense_scores + w_sparse * sparse_scores + w_colbert * colbert_scores
3493

3594

@@ -42,12 +101,12 @@ def test_m3_single_device():
42101
)
43102

44103
queries = [
45-
"What is BGE M3?",
46-
"Defination of BM25"
104+
"What is Sionic AI?",
105+
"Try https://sionicstorm.ai today!"
47106
] * 100
48107
passages = [
49-
"BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
50-
"BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
108+
"Sionic AI delivers more accessible and cost-effective AI technology addressing the various needs to boost productivity and drive innovation.",
109+
"The Large Language Model (LLM) is not for research and experimentation. We offer solutions that leverage LLM to add value to your business. Anyone can easily train and control AI."
51110
] * 100
52111

53112
queries_embeddings = model.encode_queries(
@@ -56,36 +115,33 @@ def test_m3_single_device():
56115
return_sparse=True,
57116
return_colbert_vecs=True,
58117
)
118+
59119
passages_embeddings = model.encode_corpus(
60120
passages,
61121
return_dense=True,
62122
return_sparse=True,
63123
return_colbert_vecs=True,
64124
)
65125

66-
# Select the device
67126
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
68127

69-
# dense_vecs, lexical_weights, etc. may be numpy arrays, so convert them to tensors
128+
# Compute dense scores
70129
q_dense = torch.tensor(queries_embeddings["dense_vecs"], dtype=torch.float, device=device)
71130
p_dense = torch.tensor(passages_embeddings["dense_vecs"], dtype=torch.float, device=device)
72131
dense_scores = q_dense @ p_dense.T
73132

74-
# Also convert sparse_scores from a numpy array to a tensor
75133
sparse_scores_np = model.compute_lexical_matching_score(
76134
queries_embeddings["lexical_weights"],
77135
passages_embeddings["lexical_weights"]
78136
)
137+
79138
sparse_scores = torch.tensor(sparse_scores_np, dtype=torch.float, device=device)
80139

81-
# Pad colbert_vecs, then convert them to a tensor
82140
query_colbert_vecs = pad_colbert_vecs(queries_embeddings["colbert_vecs"], device)
83141
passage_colbert_vecs = pad_colbert_vecs(passages_embeddings["colbert_vecs"], device)
84-
85142
colbert_scores = compute_colbert_scores(query_colbert_vecs, passage_colbert_vecs)
86143

87-
# All scores are torch.Tensors, so they can be combined without errors
88-
hybrid_scores = hybrid_dbfs_ensemble(dense_scores, sparse_scores, colbert_scores)
144+
hybrid_scores = hybrid_dbfs_ensemble_simple_linear_combination(dense_scores, sparse_scores, colbert_scores)
89145

90146
print("Dense score:\n", dense_scores[:2, :2])
91147
print("Sparse score:\n", sparse_scores[:2, :2])
@@ -95,11 +151,14 @@ def test_m3_single_device():
95151

96152
if __name__ == '__main__':
97153
test_m3_single_device()
154+
print("Expected Vector Scores")
98155
print("--------------------------------")
99-
print("Expected Output for Dense & Sparse (original):")
100156
print("Dense score:")
101157
print(" [[0.626 0.3477]\n [0.3496 0.678 ]]")
102158
print("Sparse score:")
103159
print(" [[0.19554901 0.00880432]\n [0. 0.18036556]]")
160+
print("ColBERT score:")
161+
print("[[5.8061, 3.1195] \n [5.6822, 4.6513]]")
162+
print("Hybrid DBSF Ensemble score:")
163+
print("[[0.9822, 0.5125] \n [0.8127, 0.6958]]")
104164
print("--------------------------------")
105-
print("ColBERT and Hybrid DBSF scores will vary depending on the actual embeddings.")

0 commit comments

Comments
 (0)