summaryrefslogtreecommitdiff
path: root/lib/libcppbor/cppbor.cpp
diff options
context:
space:
mode:
authorjacqueline <me@jacqueline.id.au>2023-10-13 15:05:49 +1100
committerjacqueline <me@jacqueline.id.au>2023-10-13 15:05:49 +1100
commitafbf3c31f4d1a605c264f719531f4183ee5a3022 (patch)
tree43c75029ff6dfe3e44137f6b3d0de3498f247bf2 /lib/libcppbor/cppbor.cpp
parent20d1c280a77eadcea18438453dc37daaf1d85e2d (diff)
downloadtangara-fw-afbf3c31f4d1a605c264f719531f4183ee5a3022.tar.gz
Use libcppbor for much much nicer db encoding
Diffstat (limited to 'lib/libcppbor/cppbor.cpp')
-rw-r--r--lib/libcppbor/cppbor.cpp599
1 files changed, 599 insertions, 0 deletions
diff --git a/lib/libcppbor/cppbor.cpp b/lib/libcppbor/cppbor.cpp
new file mode 100644
index 00000000..c66ca752
--- /dev/null
+++ b/lib/libcppbor/cppbor.cpp
@@ -0,0 +1,599 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cppbor.h"
+
+#include <inttypes.h>
+#include <cstdint>
+
+#include "cppbor_parse.h"
+
+using std::string;
+using std::vector;
+
+#define CHECK(x) (void)(x)
+
+namespace cppbor {
+
+namespace {
+
+template <typename T, typename Iterator, typename = std::enable_if<std::is_unsigned<T>::value>>
+Iterator writeBigEndian(T value, Iterator pos) {
+ for (unsigned i = 0; i < sizeof(value); ++i) {
+ *pos++ = static_cast<uint8_t>(value >> (8 * (sizeof(value) - 1)));
+ value = static_cast<T>(value << 8);
+ }
+ return pos;
+}
+
+template <typename T, typename = std::enable_if<std::is_unsigned<T>::value>>
+void writeBigEndian(T value, std::function<void(uint8_t)>& cb) {
+ for (unsigned i = 0; i < sizeof(value); ++i) {
+ cb(static_cast<uint8_t>(value >> (8 * (sizeof(value) - 1))));
+ value = static_cast<T>(value << 8);
+ }
+}
+
+bool cborAreAllElementsNonCompound(const Item* compoundItem) {
+ if (compoundItem->type() == ARRAY) {
+ const Array* array = compoundItem->asArray();
+ for (size_t n = 0; n < array->size(); n++) {
+ const Item* entry = (*array)[n].get();
+ switch (entry->type()) {
+ case ARRAY:
+ case MAP:
+ return false;
+ default:
+ break;
+ }
+ }
+ } else {
+ const Map* map = compoundItem->asMap();
+ for (auto& [keyEntry, valueEntry] : *map) {
+ switch (keyEntry->type()) {
+ case ARRAY:
+ case MAP:
+ return false;
+ default:
+ break;
+ }
+ switch (valueEntry->type()) {
+ case ARRAY:
+ case MAP:
+ return false;
+ default:
+ break;
+ }
+ }
+ }
+ return true;
+}
+
+bool prettyPrintInternal(const Item* item, string& out, size_t indent, size_t maxBStrSize,
+ const vector<string>& mapKeysToNotPrint) {
+ if (!item) {
+ out.append("<NULL>");
+ return false;
+ }
+
+ char buf[80];
+
+ string indentString(indent, ' ');
+
+ size_t tagCount = item->semanticTagCount();
+ while (tagCount > 0) {
+ --tagCount;
+ snprintf(buf, sizeof(buf), "tag %" PRIu64 " ", item->semanticTag(tagCount));
+ out.append(buf);
+ }
+
+ switch (item->type()) {
+ case SEMANTIC:
+ // Handled above.
+ break;
+
+ case UINT:
+ snprintf(buf, sizeof(buf), "%" PRIu64, item->asUint()->unsignedValue());
+ out.append(buf);
+ break;
+
+ case NINT:
+ snprintf(buf, sizeof(buf), "%" PRId64, item->asNint()->value());
+ out.append(buf);
+ break;
+
+ case BSTR: {
+ const uint8_t* valueData;
+ size_t valueSize;
+ const Bstr* bstr = item->asBstr();
+ if (bstr != nullptr) {
+ const vector<uint8_t>& value = bstr->value();
+ valueData = value.data();
+ valueSize = value.size();
+ } else {
+ const ViewBstr* viewBstr = item->asViewBstr();
+ assert(viewBstr != nullptr);
+
+ valueData = viewBstr->view().data();
+ valueSize = viewBstr->view().size();
+ }
+
+ if (valueSize > maxBStrSize) {
+ snprintf(buf, sizeof(buf), "<bstr size=%zd>", valueSize);
+ out.append(buf);
+ } else {
+ out.append("{");
+ for (size_t n = 0; n < valueSize; n++) {
+ if (n > 0) {
+ out.append(", ");
+ }
+ snprintf(buf, sizeof(buf), "0x%02x", valueData[n]);
+ out.append(buf);
+ }
+ out.append("}");
+ }
+ } break;
+
+ case TSTR:
+ out.append("'");
+ {
+ // TODO: escape "'" characters
+ if (item->asTstr() != nullptr) {
+ out.append(item->asTstr()->value().c_str());
+ } else {
+ const ViewTstr* viewTstr = item->asViewTstr();
+ assert(viewTstr != nullptr);
+ out.append(viewTstr->view());
+ }
+ }
+ out.append("'");
+ break;
+
+ case ARRAY: {
+ const Array* array = item->asArray();
+ if (array->size() == 0) {
+ out.append("[]");
+ } else if (cborAreAllElementsNonCompound(array)) {
+ out.append("[");
+ for (size_t n = 0; n < array->size(); n++) {
+ if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize,
+ mapKeysToNotPrint)) {
+ return false;
+ }
+ out.append(", ");
+ }
+ out.append("]");
+ } else {
+ out.append("[\n" + indentString);
+ for (size_t n = 0; n < array->size(); n++) {
+ out.append(" ");
+ if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize,
+ mapKeysToNotPrint)) {
+ return false;
+ }
+ out.append(",\n" + indentString);
+ }
+ out.append("]");
+ }
+ } break;
+
+ case MAP: {
+ const Map* map = item->asMap();
+
+ if (map->size() == 0) {
+ out.append("{}");
+ } else {
+ out.append("{\n" + indentString);
+ for (auto& [map_key, map_value] : *map) {
+ out.append(" ");
+
+ if (!prettyPrintInternal(map_key.get(), out, indent + 2, maxBStrSize,
+ mapKeysToNotPrint)) {
+ return false;
+ }
+ out.append(" : ");
+ if (map_key->type() == TSTR &&
+ std::find(mapKeysToNotPrint.begin(), mapKeysToNotPrint.end(),
+ map_key->asTstr()->value()) != mapKeysToNotPrint.end()) {
+ out.append("<not printed>");
+ } else {
+ if (!prettyPrintInternal(map_value.get(), out, indent + 2, maxBStrSize,
+ mapKeysToNotPrint)) {
+ return false;
+ }
+ }
+ out.append(",\n" + indentString);
+ }
+ out.append("}");
+ }
+ } break;
+
+ case SIMPLE:
+ const Bool* asBool = item->asSimple()->asBool();
+ const Null* asNull = item->asSimple()->asNull();
+ if (asBool != nullptr) {
+ out.append(asBool->value() ? "true" : "false");
+ } else if (asNull != nullptr) {
+ out.append("null");
+ } else {
+ return false;
+ }
+ break;
+ }
+
+ return true;
+}
+
+} // namespace
+
+size_t headerSize(uint64_t addlInfo) {
+ if (addlInfo < ONE_BYTE_LENGTH) return 1;
+ if (addlInfo <= std::numeric_limits<uint8_t>::max()) return 2;
+ if (addlInfo <= std::numeric_limits<uint16_t>::max()) return 3;
+ if (addlInfo <= std::numeric_limits<uint32_t>::max()) return 5;
+ return 9;
+}
+
+uint8_t* encodeHeader(MajorType type, uint64_t addlInfo, uint8_t* pos, const uint8_t* end) {
+ size_t sz = headerSize(addlInfo);
+ if (end - pos < static_cast<ssize_t>(sz)) return nullptr;
+ switch (sz) {
+ case 1:
+ *pos++ = type | static_cast<uint8_t>(addlInfo);
+ return pos;
+ case 2:
+ *pos++ = type | ONE_BYTE_LENGTH;
+ *pos++ = static_cast<uint8_t>(addlInfo);
+ return pos;
+ case 3:
+ *pos++ = type | TWO_BYTE_LENGTH;
+ return writeBigEndian(static_cast<uint16_t>(addlInfo), pos);
+ case 5:
+ *pos++ = type | FOUR_BYTE_LENGTH;
+ return writeBigEndian(static_cast<uint32_t>(addlInfo), pos);
+ case 9:
+ *pos++ = type | EIGHT_BYTE_LENGTH;
+ return writeBigEndian(addlInfo, pos);
+ default:
+ CHECK(false); // Impossible to get here.
+ return nullptr;
+ }
+}
+
+void encodeHeader(MajorType type, uint64_t addlInfo, EncodeCallback encodeCallback) {
+ size_t sz = headerSize(addlInfo);
+ switch (sz) {
+ case 1:
+ encodeCallback(type | static_cast<uint8_t>(addlInfo));
+ break;
+ case 2:
+ encodeCallback(type | ONE_BYTE_LENGTH);
+ encodeCallback(static_cast<uint8_t>(addlInfo));
+ break;
+ case 3:
+ encodeCallback(type | TWO_BYTE_LENGTH);
+ writeBigEndian(static_cast<uint16_t>(addlInfo), encodeCallback);
+ break;
+ case 5:
+ encodeCallback(type | FOUR_BYTE_LENGTH);
+ writeBigEndian(static_cast<uint32_t>(addlInfo), encodeCallback);
+ break;
+ case 9:
+ encodeCallback(type | EIGHT_BYTE_LENGTH);
+ writeBigEndian(addlInfo, encodeCallback);
+ break;
+ default:
+ CHECK(false); // Impossible to get here.
+ }
+}
+
+bool Item::operator==(const Item& other) const& {
+ if (type() != other.type()) return false;
+ switch (type()) {
+ case UINT:
+ return *asUint() == *(other.asUint());
+ case NINT:
+ return *asNint() == *(other.asNint());
+ case BSTR:
+ if (asBstr() != nullptr && other.asBstr() != nullptr) {
+ return *asBstr() == *(other.asBstr());
+ }
+ if (asViewBstr() != nullptr && other.asViewBstr() != nullptr) {
+ return *asViewBstr() == *(other.asViewBstr());
+ }
+ // Interesting corner case: comparing a Bstr and ViewBstr with
+ // identical contents. The function currently returns false for
+ // this case.
+ // TODO: if it should return true, this needs a deep comparison
+ return false;
+ case TSTR:
+ if (asTstr() != nullptr && other.asTstr() != nullptr) {
+ return *asTstr() == *(other.asTstr());
+ }
+ if (asViewTstr() != nullptr && other.asViewTstr() != nullptr) {
+ return *asViewTstr() == *(other.asViewTstr());
+ }
+ // Same corner case as Bstr
+ return false;
+ case ARRAY:
+ return *asArray() == *(other.asArray());
+ case MAP:
+ return *asMap() == *(other.asMap());
+ case SIMPLE:
+ return *asSimple() == *(other.asSimple());
+ case SEMANTIC:
+ return *asSemanticTag() == *(other.asSemanticTag());
+ default:
+ CHECK(false); // Impossible to get here.
+ return false;
+ }
+}
+
+Nint::Nint(int64_t v) : mValue(v) {
+ CHECK(v < 0);
+}
+
+bool Simple::operator==(const Simple& other) const& {
+ if (simpleType() != other.simpleType()) return false;
+
+ switch (simpleType()) {
+ case BOOLEAN:
+ return *asBool() == *(other.asBool());
+ case NULL_T:
+ return true;
+ default:
+ CHECK(false); // Impossible to get here.
+ return false;
+ }
+}
+
+uint8_t* Bstr::encode(uint8_t* pos, const uint8_t* end) const {
+ pos = encodeHeader(mValue.size(), pos, end);
+ if (!pos || end - pos < static_cast<ptrdiff_t>(mValue.size())) return nullptr;
+ return std::copy(mValue.begin(), mValue.end(), pos);
+}
+
+void Bstr::encodeValue(EncodeCallback encodeCallback) const {
+ for (auto c : mValue) {
+ encodeCallback(c);
+ }
+}
+
+uint8_t* ViewBstr::encode(uint8_t* pos, const uint8_t* end) const {
+ pos = encodeHeader(mView.size(), pos, end);
+ if (!pos || end - pos < static_cast<ptrdiff_t>(mView.size())) return nullptr;
+ return std::copy(mView.begin(), mView.end(), pos);
+}
+
+void ViewBstr::encodeValue(EncodeCallback encodeCallback) const {
+ for (auto c : mView) {
+ encodeCallback(static_cast<uint8_t>(c));
+ }
+}
+
+uint8_t* Tstr::encode(uint8_t* pos, const uint8_t* end) const {
+ pos = encodeHeader(mValue.size(), pos, end);
+ if (!pos || end - pos < static_cast<ptrdiff_t>(mValue.size())) return nullptr;
+ return std::copy(mValue.begin(), mValue.end(), pos);
+}
+
+void Tstr::encodeValue(EncodeCallback encodeCallback) const {
+ for (auto c : mValue) {
+ encodeCallback(static_cast<uint8_t>(c));
+ }
+}
+
+uint8_t* ViewTstr::encode(uint8_t* pos, const uint8_t* end) const {
+ pos = encodeHeader(mView.size(), pos, end);
+ if (!pos || end - pos < static_cast<ptrdiff_t>(mView.size())) return nullptr;
+ return std::copy(mView.begin(), mView.end(), pos);
+}
+
+void ViewTstr::encodeValue(EncodeCallback encodeCallback) const {
+ for (auto c : mView) {
+ encodeCallback(static_cast<uint8_t>(c));
+ }
+}
+
+bool Array::operator==(const Array& other) const& {
+ return size() == other.size()
+ // Can't use vector::operator== because the contents are pointers. std::equal lets us
+ // provide a predicate that does the dereferencing.
+ && std::equal(mEntries.begin(), mEntries.end(), other.mEntries.begin(),
+ [](auto& a, auto& b) -> bool { return *a == *b; });
+}
+
+uint8_t* Array::encode(uint8_t* pos, const uint8_t* end) const {
+ pos = encodeHeader(size(), pos, end);
+ if (!pos) return nullptr;
+ for (auto& entry : mEntries) {
+ pos = entry->encode(pos, end);
+ if (!pos) return nullptr;
+ }
+ return pos;
+}
+
+void Array::encode(EncodeCallback encodeCallback) const {
+ encodeHeader(size(), encodeCallback);
+ for (auto& entry : mEntries) {
+ entry->encode(encodeCallback);
+ }
+}
+
+std::unique_ptr<Item> Array::clone() const {
+ auto res = std::make_unique<Array>();
+ for (size_t i = 0; i < mEntries.size(); i++) {
+ res->add(mEntries[i]->clone());
+ }
+ return res;
+}
+
+bool Map::operator==(const Map& other) const& {
+ return size() == other.size()
+ // Can't use vector::operator== because the contents are pairs of pointers. std::equal
+ // lets us provide a predicate that does the dereferencing.
+ && std::equal(begin(), end(), other.begin(), [](auto& a, auto& b) {
+ return *a.first == *b.first && *a.second == *b.second;
+ });
+}
+
+uint8_t* Map::encode(uint8_t* pos, const uint8_t* end) const {
+ pos = encodeHeader(size(), pos, end);
+ if (!pos) return nullptr;
+ for (auto& entry : mEntries) {
+ pos = entry.first->encode(pos, end);
+ if (!pos) return nullptr;
+ pos = entry.second->encode(pos, end);
+ if (!pos) return nullptr;
+ }
+ return pos;
+}
+
+void Map::encode(EncodeCallback encodeCallback) const {
+ encodeHeader(size(), encodeCallback);
+ for (auto& entry : mEntries) {
+ entry.first->encode(encodeCallback);
+ entry.second->encode(encodeCallback);
+ }
+}
+
+bool Map::keyLess(const Item* a, const Item* b) {
+ // CBOR map canonicalization rules are:
+
+ // 1. If two keys have different lengths, the shorter one sorts earlier.
+ if (a->encodedSize() < b->encodedSize()) return true;
+ if (a->encodedSize() > b->encodedSize()) return false;
+
+ // 2. If two keys have the same length, the one with the lower value in (byte-wise) lexical
+ // order sorts earlier. This requires encoding both items.
+ auto encodedA = a->encode();
+ auto encodedB = b->encode();
+
+ return std::lexicographical_compare(encodedA.begin(), encodedA.end(), //
+ encodedB.begin(), encodedB.end());
+}
+
+void recursivelyCanonicalize(std::unique_ptr<Item>& item) {
+ switch (item->type()) {
+ case UINT:
+ case NINT:
+ case BSTR:
+ case TSTR:
+ case SIMPLE:
+ return;
+
+ case ARRAY:
+ std::for_each(item->asArray()->begin(), item->asArray()->end(),
+ recursivelyCanonicalize);
+ return;
+
+ case MAP:
+ item->asMap()->canonicalize(true /* recurse */);
+ return;
+
+ case SEMANTIC:
+ // This can't happen. SemanticTags delegate their type() method to the contained Item's
+ // type.
+ assert(false);
+ return;
+ }
+}
+
+Map& Map::canonicalize(bool recurse) & {
+ if (recurse) {
+ for (auto& entry : mEntries) {
+ recursivelyCanonicalize(entry.first);
+ recursivelyCanonicalize(entry.second);
+ }
+ }
+
+ if (size() < 2 || mCanonicalized) {
+ // Trivially or already canonical; do nothing.
+ return *this;
+ }
+
+ std::sort(begin(), end(),
+ [](auto& a, auto& b) { return keyLess(a.first.get(), b.first.get()); });
+ mCanonicalized = true;
+ return *this;
+}
+
+std::unique_ptr<Item> Map::clone() const {
+ auto res = std::make_unique<Map>();
+ for (auto& [key, value] : *this) {
+ res->add(key->clone(), value->clone());
+ }
+ res->mCanonicalized = mCanonicalized;
+ return res;
+}
+
+std::unique_ptr<Item> SemanticTag::clone() const {
+ return std::make_unique<SemanticTag>(mValue, mTaggedItem->clone());
+}
+
+uint8_t* SemanticTag::encode(uint8_t* pos, const uint8_t* end) const {
+ // Can't use the encodeHeader() method that calls type() to get the major type, since that will
+ // return the tagged Item's type.
+ pos = ::cppbor::encodeHeader(kMajorType, mValue, pos, end);
+ if (!pos) return nullptr;
+ return mTaggedItem->encode(pos, end);
+}
+
+void SemanticTag::encode(EncodeCallback encodeCallback) const {
+ // Can't use the encodeHeader() method that calls type() to get the major type, since that will
+ // return the tagged Item's type.
+ ::cppbor::encodeHeader(kMajorType, mValue, encodeCallback);
+ mTaggedItem->encode(encodeCallback);
+}
+
+size_t SemanticTag::semanticTagCount() const {
+ size_t levelCount = 1; // Count this level.
+ const SemanticTag* cur = this;
+ while (cur->mTaggedItem && (cur = cur->mTaggedItem->asSemanticTag()) != nullptr) ++levelCount;
+ return levelCount;
+}
+
+uint64_t SemanticTag::semanticTag(size_t nesting) const {
+ // Getting the value of a specific nested tag is a bit tricky, because we start with the outer
+ // tag and don't know how many are inside. We count the number of nesting levels to find out
+ // how many there are in total, then to get the one we want we have to walk down levelCount -
+ // nesting steps.
+ size_t levelCount = semanticTagCount();
+ if (nesting >= levelCount) return 0;
+
+ levelCount -= nesting;
+ const SemanticTag* cur = this;
+ while (--levelCount > 0) cur = cur->mTaggedItem->asSemanticTag();
+
+ return cur->mValue;
+}
+
+string prettyPrint(const Item* item, size_t maxBStrSize, const vector<string>& mapKeysToNotPrint) {
+ string out;
+ prettyPrintInternal(item, out, 0, maxBStrSize, mapKeysToNotPrint);
+ return out;
+}
+string prettyPrint(const vector<uint8_t>& encodedCbor, size_t maxBStrSize,
+ const vector<string>& mapKeysToNotPrint) {
+ auto [item, _, message] = parse(encodedCbor);
+ if (item == nullptr) {
+ return "";
+ }
+
+ return prettyPrint(item.get(), maxBStrSize, mapKeysToNotPrint);
+}
+
+} // namespace cppbor