Skip to content

Commit a8025ba

Browse files
[AUTO-CHERRYPICK] protobuf: patch CVE-2022-1941 - branch main (#10018)
Co-authored-by: Archana Choudhary <36061892+arc9693@users.noreply.github.com>
1 parent 76613af commit a8025ba

2 files changed

Lines changed: 371 additions & 2 deletions

File tree

SPECS/protobuf/CVE-2022-1941.patch

Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
From 55815e423bb82cc828836bbd60c79c1f9a195763 Mon Sep 17 00:00:00 2001
2+
From: Deanna Garcia <deannagarcia@google.com>
3+
Date: Tue, 13 Sep 2022 17:20:00 +0000
4+
Subject: [PATCH] Apply patch
5+
6+
---
7+
src/google/protobuf/extension_set_inl.h | 27 +++--
8+
src/google/protobuf/wire_format.cc | 26 +++--
9+
src/google/protobuf/wire_format_lite.h | 27 +++--
10+
src/google/protobuf/wire_format_unittest.cc | 109 ++++++++++++++++++--
11+
4 files changed, 152 insertions(+), 37 deletions(-)
12+
13+
diff --git a/src/google/protobuf/extension_set_inl.h b/src/google/protobuf/extension_set_inl.h
14+
index 074784b96d50..77f95f62fd58 100644
15+
--- a/src/google/protobuf/extension_set_inl.h
16+
+++ b/src/google/protobuf/extension_set_inl.h
17+
@@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
18+
const char* ptr, const Msg* containing_type,
19+
internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
20+
std::string payload;
21+
- uint32 type_id = 0;
22+
- bool payload_read = false;
23+
+ uint32 type_id;
24+
+ enum class State { kNoTag, kHasType, kHasPayload, kDone };
25+
+ State state = State::kNoTag;
26+
+
27+
while (!ctx->Done(&ptr)) {
28+
uint32 tag = static_cast<uint8>(*ptr++);
29+
if (tag == WireFormatLite::kMessageSetTypeIdTag) {
30+
uint64 tmp;
31+
ptr = ParseBigVarint(ptr, &tmp);
32+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
33+
- type_id = tmp;
34+
- if (payload_read) {
35+
+ if (state == State::kNoTag) {
36+
+ type_id = tmp;
37+
+ state = State::kHasType;
38+
+ } else if (state == State::kHasPayload) {
39+
+ type_id = tmp;
40+
ExtensionInfo extension;
41+
bool was_packed_on_wire;
42+
if (!FindExtension(2, type_id, containing_type, ctx, &extension,
43+
@@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
44+
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
45+
tmp_ctx.EndedAtLimit());
46+
}
47+
- type_id = 0;
48+
+ state = State::kDone;
49+
}
50+
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
51+
- if (type_id != 0) {
52+
+ if (state == State::kHasType) {
53+
ptr = ParseFieldMaybeLazily(static_cast<uint64>(type_id) * 8 + 2, ptr,
54+
containing_type, metadata, ctx);
55+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
56+
- type_id = 0;
57+
+ state = State::kDone;
58+
} else {
59+
+ std::string tmp;
60+
int32 size = ReadSize(&ptr);
61+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
62+
- ptr = ctx->ReadString(ptr, size, &payload);
63+
+ ptr = ctx->ReadString(ptr, size, &tmp);
64+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
65+
- payload_read = true;
66+
+ if (state == State::kNoTag) {
67+
+ payload = std::move(tmp);
68+
+ state = State::kHasPayload;
69+
+ }
70+
}
71+
} else {
72+
ptr = ReadTag(ptr - 1, &tag);
73+
diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc
74+
index c30b7abff634..382d01ea0cfb 100644
75+
--- a/src/google/protobuf/wire_format.cc
76+
+++ b/src/google/protobuf/wire_format.cc
77+
@@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser {
78+
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
79+
// Parse a MessageSetItem
80+
auto metadata = reflection->MutableInternalMetadata(msg);
81+
+ enum class State { kNoTag, kHasType, kHasPayload, kDone };
82+
+ State state = State::kNoTag;
83+
+
84+
std::string payload;
85+
uint32 type_id = 0;
86+
- bool payload_read = false;
87+
while (!ctx->Done(&ptr)) {
88+
// We use 64 bit tags in order to allow typeid's that span the whole
89+
// range of 32 bit numbers.
90+
@@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser {
91+
uint64 tmp;
92+
ptr = ParseBigVarint(ptr, &tmp);
93+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
94+
- type_id = tmp;
95+
- if (payload_read) {
96+
+ if (state == State::kNoTag) {
97+
+ type_id = tmp;
98+
+ state = State::kHasType;
99+
+ } else if (state == State::kHasPayload) {
100+
+ type_id = tmp;
101+
const FieldDescriptor* field;
102+
if (ctx->data().pool == nullptr) {
103+
field = reflection->FindKnownExtensionByNumber(type_id);
104+
@@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser {
105+
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
106+
tmp_ctx.EndedAtLimit());
107+
}
108+
- type_id = 0;
109+
+ state = State::kDone;
110+
}
111+
continue;
112+
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
113+
- if (type_id == 0) {
114+
+ if (state == State::kNoTag) {
115+
int32 size = ReadSize(&ptr);
116+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
117+
ptr = ctx->ReadString(ptr, size, &payload);
118+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
119+
- payload_read = true;
120+
- } else {
121+
+ state = State::kHasPayload;
122+
+ } else if (state == State::kHasType) {
123+
// We're now parsing the payload
124+
const FieldDescriptor* field = nullptr;
125+
if (descriptor->IsExtensionNumber(type_id)) {
126+
@@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser {
127+
ptr = WireFormat::_InternalParseAndMergeField(
128+
msg, ptr, ctx, static_cast<uint64>(type_id) * 8 + 2, reflection,
129+
field);
130+
- type_id = 0;
131+
+ state = State::kDone;
132+
+ } else {
133+
+ int32 size = ReadSize(&ptr);
134+
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
135+
+ ptr = ctx->Skip(ptr, size);
136+
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
137+
}
138+
} else {
139+
// An unknown field in MessageSetItem.
140+
diff --git a/src/google/protobuf/wire_format_lite.h b/src/google/protobuf/wire_format_lite.h
141+
index f2a3cad82816..0b13096ccbf7 100644
142+
--- a/src/google/protobuf/wire_format_lite.h
143+
+++ b/src/google/protobuf/wire_format_lite.h
144+
@@ -1798,6 +1798,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
145+
// we can parse it later.
146+
std::string message_data;
147+
148+
+ enum class State { kNoTag, kHasType, kHasPayload, kDone };
149+
+ State state = State::kNoTag;
150+
+
151+
while (true) {
152+
const uint32 tag = input->ReadTagNoLastTag();
153+
if (tag == 0) return false;
154+
@@ -1806,26 +1809,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
155+
case WireFormatLite::kMessageSetTypeIdTag: {
156+
uint32 type_id;
157+
if (!input->ReadVarint32(&type_id)) return false;
158+
- last_type_id = type_id;
159+
-
160+
- if (!message_data.empty()) {
161+
+ if (state == State::kNoTag) {
162+
+ last_type_id = type_id;
163+
+ state = State::kHasType;
164+
+ } else if (state == State::kHasPayload) {
165+
// We saw some message data before the type_id. Have to parse it
166+
// now.
167+
io::CodedInputStream sub_input(
168+
reinterpret_cast<const uint8*>(message_data.data()),
169+
static_cast<int>(message_data.size()));
170+
sub_input.SetRecursionLimit(input->RecursionBudget());
171+
- if (!ms.ParseField(last_type_id, &sub_input)) {
172+
+ if (!ms.ParseField(type_id, &sub_input)) {
173+
return false;
174+
}
175+
message_data.clear();
176+
+ state = State::kDone;
177+
}
178+
179+
break;
180+
}
181+
182+
case WireFormatLite::kMessageSetMessageTag: {
183+
- if (last_type_id == 0) {
184+
+ if (state == State::kHasType) {
185+
+ // Already saw type_id, so we can parse this directly.
186+
+ if (!ms.ParseField(last_type_id, input)) {
187+
+ return false;
188+
+ }
189+
+ state = State::kDone;
190+
+ } else if (state == State::kNoTag) {
191+
// We haven't seen a type_id yet. Append this data to message_data.
192+
uint32 length;
193+
if (!input->ReadVarint32(&length)) return false;
194+
@@ -1836,11 +1847,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
195+
auto ptr = reinterpret_cast<uint8*>(&message_data[0]);
196+
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
197+
if (!input->ReadRaw(ptr, length)) return false;
198+
+ state = State::kHasPayload;
199+
} else {
200+
- // Already saw type_id, so we can parse this directly.
201+
- if (!ms.ParseField(last_type_id, input)) {
202+
- return false;
203+
- }
204+
+ if (!ms.SkipField(tag, input)) return false;
205+
}
206+
207+
break;
208+
diff --git a/src/google/protobuf/wire_format_unittest.cc b/src/google/protobuf/wire_format_unittest.cc
209+
index e75fc316f875..8d767b2833eb 100644
210+
--- a/src/google/protobuf/wire_format_unittest.cc
211+
+++ b/src/google/protobuf/wire_format_unittest.cc
212+
@@ -46,6 +46,7 @@
213+
#include <google/protobuf/io/zero_copy_stream_impl.h>
214+
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
215+
#include <google/protobuf/descriptor.h>
216+
+#include <google/protobuf/dynamic_message.h>
217+
#include <google/protobuf/wire_format_lite.h>
218+
#include <google/protobuf/testing/googletest.h>
219+
#include <google/protobuf/stubs/logging.h>
220+
@@ -585,30 +586,56 @@ TEST(WireFormatTest, ParseMessageSet) {
221+
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
222+
}
223+
224+
-TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
225+
+namespace {
226+
+std::string BuildMessageSetItemStart() {
227+
std::string data;
228+
{
229+
- unittest::TestMessageSetExtension1 message;
230+
- message.set_i(123);
231+
- // Build a MessageSet manually with its message content put before its
232+
- // type_id.
233+
io::StringOutputStream output_stream(&data);
234+
io::CodedOutputStream coded_output(&output_stream);
235+
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
236+
+ }
237+
+ return data;
238+
+}
239+
+std::string BuildMessageSetItemEnd() {
240+
+ std::string data;
241+
+ {
242+
+ io::StringOutputStream output_stream(&data);
243+
+ io::CodedOutputStream coded_output(&output_stream);
244+
+ coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
245+
+ }
246+
+ return data;
247+
+}
248+
+std::string BuildMessageSetTestExtension1(int value = 123) {
249+
+ std::string data;
250+
+ {
251+
+ unittest::TestMessageSetExtension1 message;
252+
+ message.set_i(value);
253+
+ io::StringOutputStream output_stream(&data);
254+
+ io::CodedOutputStream coded_output(&output_stream);
255+
// Write the message content first.
256+
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
257+
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
258+
&coded_output);
259+
coded_output.WriteVarint32(message.ByteSizeLong());
260+
message.SerializeWithCachedSizes(&coded_output);
261+
- // Write the type id.
262+
- uint32 type_id = message.GetDescriptor()->extension(0)->number();
263+
+ }
264+
+ return data;
265+
+}
266+
+std::string BuildMessageSetItemTypeId(int extension_number) {
267+
+ std::string data;
268+
+ {
269+
+ io::StringOutputStream output_stream(&data);
270+
+ io::CodedOutputStream coded_output(&output_stream);
271+
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
272+
- type_id, &coded_output);
273+
- coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
274+
+ extension_number, &coded_output);
275+
}
276+
+ return data;
277+
+}
278+
+void ValidateTestMessageSet(const std::string& test_case,
279+
+ const std::string& data) {
280+
+ SCOPED_TRACE(test_case);
281+
{
282+
- proto2_wireformat_unittest::TestMessageSet message_set;
283+
+ ::proto2_wireformat_unittest::TestMessageSet message_set;
284+
ASSERT_TRUE(message_set.ParseFromString(data));
285+
286+
EXPECT_EQ(123,
287+
@@ -616,10 +643,15 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
288+
.GetExtension(
289+
unittest::TestMessageSetExtension1::message_set_extension)
290+
.i());
291+
+
292+
+ // Make sure it does not contain anything else.
293+
+ message_set.ClearExtension(
294+
+ unittest::TestMessageSetExtension1::message_set_extension);
295+
+ EXPECT_EQ(message_set.SerializeAsString(), "");
296+
}
297+
{
298+
// Test parse the message via Reflection.
299+
- proto2_wireformat_unittest::TestMessageSet message_set;
300+
+ ::proto2_wireformat_unittest::TestMessageSet message_set;
301+
io::CodedInputStream input(reinterpret_cast<const uint8*>(data.data()),
302+
data.size());
303+
EXPECT_TRUE(WireFormat::ParseAndMergePartial(&input, &message_set));
304+
@@ -631,6 +663,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
305+
unittest::TestMessageSetExtension1::message_set_extension)
306+
.i());
307+
}
308+
+ {
309+
+ // Test parse the message via DynamicMessage.
310+
+ DynamicMessageFactory factory;
311+
+ std::unique_ptr<Message> msg(
312+
+ factory
313+
+ .GetPrototype(
314+
+ ::proto2_wireformat_unittest::TestMessageSet::descriptor())
315+
+ ->New());
316+
+ msg->ParseFromString(data);
317+
+ auto* reflection = msg->GetReflection();
318+
+ std::vector<const FieldDescriptor*> fields;
319+
+ reflection->ListFields(*msg, &fields);
320+
+ ASSERT_EQ(fields.size(), 1);
321+
+ const auto& sub = reflection->GetMessage(*msg, fields[0]);
322+
+ reflection = sub.GetReflection();
323+
+ EXPECT_EQ(123, reflection->GetInt32(
324+
+ sub, sub.GetDescriptor()->FindFieldByName("i")));
325+
+ }
326+
+}
327+
+} // namespace
328+
+
329+
+TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
330+
+ std::string start = BuildMessageSetItemStart();
331+
+ std::string end = BuildMessageSetItemEnd();
332+
+ std::string id = BuildMessageSetItemTypeId(
333+
+ unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
334+
+ std::string message = BuildMessageSetTestExtension1();
335+
+
336+
+ ValidateTestMessageSet("id + message", start + id + message + end);
337+
+ ValidateTestMessageSet("message + id", start + message + id + end);
338+
+}
339+
+
340+
+TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
341+
+ std::string start = BuildMessageSetItemStart();
342+
+ std::string end = BuildMessageSetItemEnd();
343+
+ std::string id = BuildMessageSetItemTypeId(
344+
+ unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
345+
+ std::string other_id = BuildMessageSetItemTypeId(123456);
346+
+ std::string message = BuildMessageSetTestExtension1();
347+
+ std::string other_message = BuildMessageSetTestExtension1(321);
348+
+
349+
+ // Double id
350+
+ ValidateTestMessageSet("id + other_id + message",
351+
+ start + id + other_id + message + end);
352+
+ ValidateTestMessageSet("id + message + other_id",
353+
+ start + id + message + other_id + end);
354+
+ ValidateTestMessageSet("message + id + other_id",
355+
+ start + message + id + other_id + end);
356+
+ // Double message
357+
+ ValidateTestMessageSet("id + message + other_message",
358+
+ start + id + message + other_message + end);
359+
+ ValidateTestMessageSet("message + id + other_message",
360+
+ start + message + id + other_message + end);
361+
+ ValidateTestMessageSet("message + other_message + id",
362+
+ start + message + other_message + id + end);
363+
}
364+
365+
void SerializeReverseOrder(

0 commit comments

Comments
 (0)