|
| 1 | +From 55815e423bb82cc828836bbd60c79c1f9a195763 Mon Sep 17 00:00:00 2001 |
| 2 | +From: Deanna Garcia <deannagarcia@google.com> |
| 3 | +Date: Tue, 13 Sep 2022 17:20:00 +0000 |
| 4 | +Subject: [PATCH] Apply patch |
| 5 | + |
| 6 | +--- |
| 7 | + src/google/protobuf/extension_set_inl.h | 27 +++-- |
| 8 | + src/google/protobuf/wire_format.cc | 26 +++-- |
| 9 | + src/google/protobuf/wire_format_lite.h | 27 +++-- |
| 10 | + src/google/protobuf/wire_format_unittest.cc | 109 ++++++++++++++++++-- |
| 11 | + 4 files changed, 152 insertions(+), 37 deletions(-) |
| 12 | + |
| 13 | +diff --git a/src/google/protobuf/extension_set_inl.h b/src/google/protobuf/extension_set_inl.h |
| 14 | +index 074784b96d50..77f95f62fd58 100644 |
| 15 | +--- a/src/google/protobuf/extension_set_inl.h |
| 16 | ++++ b/src/google/protobuf/extension_set_inl.h |
| 17 | +@@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( |
| 18 | + const char* ptr, const Msg* containing_type, |
| 19 | + internal::InternalMetadata* metadata, internal::ParseContext* ctx) { |
| 20 | + std::string payload; |
| 21 | +- uint32 type_id = 0; |
| 22 | +- bool payload_read = false; |
| 23 | ++ uint32 type_id; |
| 24 | ++ enum class State { kNoTag, kHasType, kHasPayload, kDone }; |
| 25 | ++ State state = State::kNoTag; |
| 26 | ++ |
| 27 | + while (!ctx->Done(&ptr)) { |
| 28 | + uint32 tag = static_cast<uint8>(*ptr++); |
| 29 | + if (tag == WireFormatLite::kMessageSetTypeIdTag) { |
| 30 | + uint64 tmp; |
| 31 | + ptr = ParseBigVarint(ptr, &tmp); |
| 32 | + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); |
| 33 | +- type_id = tmp; |
| 34 | +- if (payload_read) { |
| 35 | ++ if (state == State::kNoTag) { |
| 36 | ++ type_id = tmp; |
| 37 | ++ state = State::kHasType; |
| 38 | ++ } else if (state == State::kHasPayload) { |
| 39 | ++ type_id = tmp; |
| 40 | + ExtensionInfo extension; |
| 41 | + bool was_packed_on_wire; |
| 42 | + if (!FindExtension(2, type_id, containing_type, ctx, &extension, |
| 43 | +@@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( |
| 44 | + GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && |
| 45 | + tmp_ctx.EndedAtLimit()); |
| 46 | + } |
| 47 | +- type_id = 0; |
| 48 | ++ state = State::kDone; |
| 49 | + } |
| 50 | + } else if (tag == WireFormatLite::kMessageSetMessageTag) { |
| 51 | +- if (type_id != 0) { |
| 52 | ++ if (state == State::kHasType) { |
| 53 | + ptr = ParseFieldMaybeLazily(static_cast<uint64>(type_id) * 8 + 2, ptr, |
| 54 | + containing_type, metadata, ctx); |
| 55 | + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); |
| 56 | +- type_id = 0; |
| 57 | ++ state = State::kDone; |
| 58 | + } else { |
| 59 | ++ std::string tmp; |
| 60 | + int32 size = ReadSize(&ptr); |
| 61 | + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); |
| 62 | +- ptr = ctx->ReadString(ptr, size, &payload); |
| 63 | ++ ptr = ctx->ReadString(ptr, size, &tmp); |
| 64 | + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); |
| 65 | +- payload_read = true; |
| 66 | ++ if (state == State::kNoTag) { |
| 67 | ++ payload = std::move(tmp); |
| 68 | ++ state = State::kHasPayload; |
| 69 | ++ } |
| 70 | + } |
| 71 | + } else { |
| 72 | + ptr = ReadTag(ptr - 1, &tag); |
| 73 | +diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc |
| 74 | +index c30b7abff634..382d01ea0cfb 100644 |
| 75 | +--- a/src/google/protobuf/wire_format.cc |
| 76 | ++++ b/src/google/protobuf/wire_format.cc |
| 77 | +@@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser { |
| 78 | + const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) { |
| 79 | + // Parse a MessageSetItem |
| 80 | + auto metadata = reflection->MutableInternalMetadata(msg); |
| 81 | ++ enum class State { kNoTag, kHasType, kHasPayload, kDone }; |
| 82 | ++ State state = State::kNoTag; |
| 83 | ++ |
| 84 | + std::string payload; |
| 85 | + uint32 type_id = 0; |
| 86 | +- bool payload_read = false; |
| 87 | + while (!ctx->Done(&ptr)) { |
| 88 | + // We use 64 bit tags in order to allow typeid's that span the whole |
| 89 | + // range of 32 bit numbers. |
| 90 | +@@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser { |
| 91 | + uint64 tmp; |
| 92 | + ptr = ParseBigVarint(ptr, &tmp); |
| 93 | + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); |
| 94 | +- type_id = tmp; |
| 95 | +- if (payload_read) { |
| 96 | ++ if (state == State::kNoTag) { |
| 97 | ++ type_id = tmp; |
| 98 | ++ state = State::kHasType; |
| 99 | ++ } else if (state == State::kHasPayload) { |
| 100 | ++ type_id = tmp; |
| 101 | + const FieldDescriptor* field; |
| 102 | + if (ctx->data().pool == nullptr) { |
| 103 | + field = reflection->FindKnownExtensionByNumber(type_id); |
| 104 | +@@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser { |
| 105 | + GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && |
| 106 | + tmp_ctx.EndedAtLimit()); |
| 107 | + } |
| 108 | +- type_id = 0; |
| 109 | ++ state = State::kDone; |
| 110 | + } |
| 111 | + continue; |
| 112 | + } else if (tag == WireFormatLite::kMessageSetMessageTag) { |
| 113 | +- if (type_id == 0) { |
| 114 | ++ if (state == State::kNoTag) { |
| 115 | + int32 size = ReadSize(&ptr); |
| 116 | + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); |
| 117 | + ptr = ctx->ReadString(ptr, size, &payload); |
| 118 | + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); |
| 119 | +- payload_read = true; |
| 120 | +- } else { |
| 121 | ++ state = State::kHasPayload; |
| 122 | ++ } else if (state == State::kHasType) { |
| 123 | + // We're now parsing the payload |
| 124 | + const FieldDescriptor* field = nullptr; |
| 125 | + if (descriptor->IsExtensionNumber(type_id)) { |
| 126 | +@@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser { |
| 127 | + ptr = WireFormat::_InternalParseAndMergeField( |
| 128 | + msg, ptr, ctx, static_cast<uint64>(type_id) * 8 + 2, reflection, |
| 129 | + field); |
| 130 | +- type_id = 0; |
| 131 | ++ state = State::kDone; |
| 132 | ++ } else { |
| 133 | ++ int32 size = ReadSize(&ptr); |
| 134 | ++ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); |
| 135 | ++ ptr = ctx->Skip(ptr, size); |
| 136 | ++ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); |
| 137 | + } |
| 138 | + } else { |
| 139 | + // An unknown field in MessageSetItem. |
| 140 | +diff --git a/src/google/protobuf/wire_format_lite.h b/src/google/protobuf/wire_format_lite.h |
| 141 | +index f2a3cad82816..0b13096ccbf7 100644 |
| 142 | +--- a/src/google/protobuf/wire_format_lite.h |
| 143 | ++++ b/src/google/protobuf/wire_format_lite.h |
| 144 | +@@ -1798,6 +1798,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { |
| 145 | + // we can parse it later. |
| 146 | + std::string message_data; |
| 147 | + |
| 148 | ++ enum class State { kNoTag, kHasType, kHasPayload, kDone }; |
| 149 | ++ State state = State::kNoTag; |
| 150 | ++ |
| 151 | + while (true) { |
| 152 | + const uint32 tag = input->ReadTagNoLastTag(); |
| 153 | + if (tag == 0) return false; |
| 154 | +@@ -1806,26 +1809,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { |
| 155 | + case WireFormatLite::kMessageSetTypeIdTag: { |
| 156 | + uint32 type_id; |
| 157 | + if (!input->ReadVarint32(&type_id)) return false; |
| 158 | +- last_type_id = type_id; |
| 159 | +- |
| 160 | +- if (!message_data.empty()) { |
| 161 | ++ if (state == State::kNoTag) { |
| 162 | ++ last_type_id = type_id; |
| 163 | ++ state = State::kHasType; |
| 164 | ++ } else if (state == State::kHasPayload) { |
| 165 | + // We saw some message data before the type_id. Have to parse it |
| 166 | + // now. |
| 167 | + io::CodedInputStream sub_input( |
| 168 | + reinterpret_cast<const uint8*>(message_data.data()), |
| 169 | + static_cast<int>(message_data.size())); |
| 170 | + sub_input.SetRecursionLimit(input->RecursionBudget()); |
| 171 | +- if (!ms.ParseField(last_type_id, &sub_input)) { |
| 172 | ++ if (!ms.ParseField(type_id, &sub_input)) { |
| 173 | + return false; |
| 174 | + } |
| 175 | + message_data.clear(); |
| 176 | ++ state = State::kDone; |
| 177 | + } |
| 178 | + |
| 179 | + break; |
| 180 | + } |
| 181 | + |
| 182 | + case WireFormatLite::kMessageSetMessageTag: { |
| 183 | +- if (last_type_id == 0) { |
| 184 | ++ if (state == State::kHasType) { |
| 185 | ++ // Already saw type_id, so we can parse this directly. |
| 186 | ++ if (!ms.ParseField(last_type_id, input)) { |
| 187 | ++ return false; |
| 188 | ++ } |
| 189 | ++ state = State::kDone; |
| 190 | ++ } else if (state == State::kNoTag) { |
| 191 | + // We haven't seen a type_id yet. Append this data to message_data. |
| 192 | + uint32 length; |
| 193 | + if (!input->ReadVarint32(&length)) return false; |
| 194 | +@@ -1836,11 +1847,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { |
| 195 | + auto ptr = reinterpret_cast<uint8*>(&message_data[0]); |
| 196 | + ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr); |
| 197 | + if (!input->ReadRaw(ptr, length)) return false; |
| 198 | ++ state = State::kHasPayload; |
| 199 | + } else { |
| 200 | +- // Already saw type_id, so we can parse this directly. |
| 201 | +- if (!ms.ParseField(last_type_id, input)) { |
| 202 | +- return false; |
| 203 | +- } |
| 204 | ++ if (!ms.SkipField(tag, input)) return false; |
| 205 | + } |
| 206 | + |
| 207 | + break; |
| 208 | +diff --git a/src/google/protobuf/wire_format_unittest.cc b/src/google/protobuf/wire_format_unittest.cc |
| 209 | +index e75fc316f875..8d767b2833eb 100644 |
| 210 | +--- a/src/google/protobuf/wire_format_unittest.cc |
| 211 | ++++ b/src/google/protobuf/wire_format_unittest.cc |
| 212 | +@@ -46,6 +46,7 @@ |
| 213 | + #include <google/protobuf/io/zero_copy_stream_impl.h> |
| 214 | + #include <google/protobuf/io/zero_copy_stream_impl_lite.h> |
| 215 | + #include <google/protobuf/descriptor.h> |
| 216 | ++#include <google/protobuf/dynamic_message.h> |
| 217 | + #include <google/protobuf/wire_format_lite.h> |
| 218 | + #include <google/protobuf/testing/googletest.h> |
| 219 | + #include <google/protobuf/stubs/logging.h> |
| 220 | +@@ -585,30 +586,56 @@ TEST(WireFormatTest, ParseMessageSet) { |
| 221 | + EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString()); |
| 222 | + } |
| 223 | + |
| 224 | +-TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) { |
| 225 | ++namespace { |
| 226 | ++std::string BuildMessageSetItemStart() { |
| 227 | + std::string data; |
| 228 | + { |
| 229 | +- unittest::TestMessageSetExtension1 message; |
| 230 | +- message.set_i(123); |
| 231 | +- // Build a MessageSet manually with its message content put before its |
| 232 | +- // type_id. |
| 233 | + io::StringOutputStream output_stream(&data); |
| 234 | + io::CodedOutputStream coded_output(&output_stream); |
| 235 | + coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag); |
| 236 | ++ } |
| 237 | ++ return data; |
| 238 | ++} |
| 239 | ++std::string BuildMessageSetItemEnd() { |
| 240 | ++ std::string data; |
| 241 | ++ { |
| 242 | ++ io::StringOutputStream output_stream(&data); |
| 243 | ++ io::CodedOutputStream coded_output(&output_stream); |
| 244 | ++ coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag); |
| 245 | ++ } |
| 246 | ++ return data; |
| 247 | ++} |
| 248 | ++std::string BuildMessageSetTestExtension1(int value = 123) { |
| 249 | ++ std::string data; |
| 250 | ++ { |
| 251 | ++ unittest::TestMessageSetExtension1 message; |
| 252 | ++ message.set_i(value); |
| 253 | ++ io::StringOutputStream output_stream(&data); |
| 254 | ++ io::CodedOutputStream coded_output(&output_stream); |
| 255 | + // Write the message content first. |
| 256 | + WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber, |
| 257 | + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, |
| 258 | + &coded_output); |
| 259 | + coded_output.WriteVarint32(message.ByteSizeLong()); |
| 260 | + message.SerializeWithCachedSizes(&coded_output); |
| 261 | +- // Write the type id. |
| 262 | +- uint32 type_id = message.GetDescriptor()->extension(0)->number(); |
| 263 | ++ } |
| 264 | ++ return data; |
| 265 | ++} |
| 266 | ++std::string BuildMessageSetItemTypeId(int extension_number) { |
| 267 | ++ std::string data; |
| 268 | ++ { |
| 269 | ++ io::StringOutputStream output_stream(&data); |
| 270 | ++ io::CodedOutputStream coded_output(&output_stream); |
| 271 | + WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber, |
| 272 | +- type_id, &coded_output); |
| 273 | +- coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag); |
| 274 | ++ extension_number, &coded_output); |
| 275 | + } |
| 276 | ++ return data; |
| 277 | ++} |
| 278 | ++void ValidateTestMessageSet(const std::string& test_case, |
| 279 | ++ const std::string& data) { |
| 280 | ++ SCOPED_TRACE(test_case); |
| 281 | + { |
| 282 | +- proto2_wireformat_unittest::TestMessageSet message_set; |
| 283 | ++ ::proto2_wireformat_unittest::TestMessageSet message_set; |
| 284 | + ASSERT_TRUE(message_set.ParseFromString(data)); |
| 285 | + |
| 286 | + EXPECT_EQ(123, |
| 287 | +@@ -616,10 +643,15 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) { |
| 288 | + .GetExtension( |
| 289 | + unittest::TestMessageSetExtension1::message_set_extension) |
| 290 | + .i()); |
| 291 | ++ |
| 292 | ++ // Make sure it does not contain anything else. |
| 293 | ++ message_set.ClearExtension( |
| 294 | ++ unittest::TestMessageSetExtension1::message_set_extension); |
| 295 | ++ EXPECT_EQ(message_set.SerializeAsString(), ""); |
| 296 | + } |
| 297 | + { |
| 298 | + // Test parse the message via Reflection. |
| 299 | +- proto2_wireformat_unittest::TestMessageSet message_set; |
| 300 | ++ ::proto2_wireformat_unittest::TestMessageSet message_set; |
| 301 | + io::CodedInputStream input(reinterpret_cast<const uint8*>(data.data()), |
| 302 | + data.size()); |
| 303 | + EXPECT_TRUE(WireFormat::ParseAndMergePartial(&input, &message_set)); |
| 304 | +@@ -631,6 +663,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) { |
| 305 | + unittest::TestMessageSetExtension1::message_set_extension) |
| 306 | + .i()); |
| 307 | + } |
| 308 | ++ { |
| 309 | ++ // Test parse the message via DynamicMessage. |
| 310 | ++ DynamicMessageFactory factory; |
| 311 | ++ std::unique_ptr<Message> msg( |
| 312 | ++ factory |
| 313 | ++ .GetPrototype( |
| 314 | ++ ::proto2_wireformat_unittest::TestMessageSet::descriptor()) |
| 315 | ++ ->New()); |
| 316 | ++ msg->ParseFromString(data); |
| 317 | ++ auto* reflection = msg->GetReflection(); |
| 318 | ++ std::vector<const FieldDescriptor*> fields; |
| 319 | ++ reflection->ListFields(*msg, &fields); |
| 320 | ++ ASSERT_EQ(fields.size(), 1); |
| 321 | ++ const auto& sub = reflection->GetMessage(*msg, fields[0]); |
| 322 | ++ reflection = sub.GetReflection(); |
| 323 | ++ EXPECT_EQ(123, reflection->GetInt32( |
| 324 | ++ sub, sub.GetDescriptor()->FindFieldByName("i"))); |
| 325 | ++ } |
| 326 | ++} |
| 327 | ++} // namespace |
| 328 | ++ |
| 329 | ++TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) { |
| 330 | ++ std::string start = BuildMessageSetItemStart(); |
| 331 | ++ std::string end = BuildMessageSetItemEnd(); |
| 332 | ++ std::string id = BuildMessageSetItemTypeId( |
| 333 | ++ unittest::TestMessageSetExtension1::descriptor()->extension(0)->number()); |
| 334 | ++ std::string message = BuildMessageSetTestExtension1(); |
| 335 | ++ |
| 336 | ++ ValidateTestMessageSet("id + message", start + id + message + end); |
| 337 | ++ ValidateTestMessageSet("message + id", start + message + id + end); |
| 338 | ++} |
| 339 | ++ |
| 340 | ++TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) { |
| 341 | ++ std::string start = BuildMessageSetItemStart(); |
| 342 | ++ std::string end = BuildMessageSetItemEnd(); |
| 343 | ++ std::string id = BuildMessageSetItemTypeId( |
| 344 | ++ unittest::TestMessageSetExtension1::descriptor()->extension(0)->number()); |
| 345 | ++ std::string other_id = BuildMessageSetItemTypeId(123456); |
| 346 | ++ std::string message = BuildMessageSetTestExtension1(); |
| 347 | ++ std::string other_message = BuildMessageSetTestExtension1(321); |
| 348 | ++ |
| 349 | ++ // Double id |
| 350 | ++ ValidateTestMessageSet("id + other_id + message", |
| 351 | ++ start + id + other_id + message + end); |
| 352 | ++ ValidateTestMessageSet("id + message + other_id", |
| 353 | ++ start + id + message + other_id + end); |
| 354 | ++ ValidateTestMessageSet("message + id + other_id", |
| 355 | ++ start + message + id + other_id + end); |
| 356 | ++ // Double message |
| 357 | ++ ValidateTestMessageSet("id + message + other_message", |
| 358 | ++ start + id + message + other_message + end); |
| 359 | ++ ValidateTestMessageSet("message + id + other_message", |
| 360 | ++ start + message + id + other_message + end); |
| 361 | ++ ValidateTestMessageSet("message + other_message + id", |
| 362 | ++ start + message + other_message + id + end); |
| 363 | + } |
| 364 | + |
| 365 | + void SerializeReverseOrder( |
0 commit comments