3 #include <wordring/serialize/serialize.hpp>
4 #include <wordring/serialize/serialize_iterator.hpp>
5 #include <wordring/static_vector/static_vector.hpp>
10 #include <initializer_list>
14 #include <type_traits>
25 using index_type = std::int32_t;
33 return lhs.m_base == rhs.m_base && lhs.m_check == rhs.m_check;
42 using index_type =
typename trie_node::index_type;
57 void operator=(index_type val)
59 if (val < 0)
throw std::invalid_argument(
"");
60 m_node->m_base = -val;
63 operator index_type()
const
65 assert(m_node->m_base <= 0);
66 return -m_node->m_base;
77 template <
typename Container>
80 template <
typename Allocator1>
83 template <
typename Container1>
86 template <
typename Container1>
90 using difference_type = std::ptrdiff_t;
91 using value_type = std::uint32_t;
92 using pointer = value_type*;
93 using reference = value_type&;
94 using iterator_category = std::input_iterator_tag;
98 using container = Container
const;
109 : m_c(std::addressof(c))
115 value_type operator*()
const
117 node_type const* p = m_c->data() + (m_index / 2);
118 return (m_index % 2) ? p->m_check : p->m_base;
136 std::uint32_t m_index;
139 template <
typename Container1>
142 assert(lhs.m_c == rhs.m_c);
143 return lhs.m_index == rhs.m_index;
146 template <
typename Container1>
147 inline bool operator!=(trie_heap_serialize_iterator<Container1>
const& lhs, trie_heap_serialize_iterator<Container1>
const& rhs)
149 assert(lhs.m_c == rhs.m_c);
150 return lhs.m_index != rhs.m_index;
199 template <
typename Allocator>
202 template <
typename Allocator1>
205 template <
typename Allocator1>
209 using index_type =
typename trie_node::index_type;
211 using container = std::vector<trie_node, Allocator>;
214 static constexpr std::uint16_t null_value = 256u;
217 using label_type = std::uint8_t;
218 using allocator_type = Allocator;
259 template <typename InputIterator, typename std::enable_if_t<std::is_integral_v<typename std::iterator_traits<InputIterator>::value_type>, std::nullptr_t> =
nullptr>
260 void assign(InputIterator first, InputIterator last)
262 using iterator_category =
typename std::iterator_traits<InputIterator>::iterator_category;
263 using value_type =
typename std::iterator_traits<InputIterator>::value_type;
264 using unsigned_type = std::make_unsigned_t<value_type>;
266 std::uint32_t constexpr n =
sizeof(index_type) /
sizeof(value_type);
268 static_assert(
sizeof(value_type) <=
sizeof(index_type));
271 if constexpr (std::is_same_v<iterator_category, std::random_access_iterator_tag>)
272 m_c.reserve(std::distance(first, last) / 2);
274 while (first != last)
276 std::make_unsigned_t<index_type> base = 0;
277 std::make_unsigned_t<index_type> check = 0;
279 if constexpr (n == 1)
286 for (std::uint32_t j = 0; j < n && first != last; ++j) base = (base << 8) + static_cast<unsigned_type>(*first++);
287 for (std::uint32_t j = 0; j < n && first != last; ++j) check = (check << 8) + static_cast<unsigned_type>(*first++);
290 m_c.push_back(
node_type{
static_cast<index_type
>(base),
static_cast<index_type
>(check) });
367 m_c.insert(m_c.begin(), 2,
trie_node{ 0, 0 });
370 void swap(
trie_heap& other) { m_c.swap(other.m_c); }
378 explicit trie_heap(allocator_type
const& alloc)
379 : m_c(2, { 0, 0 }, alloc)
390 trie_heap(std::initializer_list<trie_node> il, allocator_type
const& alloc = allocator_type())
395 index_type limit()
const {
return m_c.size(); }
397 void reserve(std::size_t n, index_type before = 0)
399 assert(0 <= before && before < limit());
401 index_type
id = m_c.size();
402 m_c.insert(m_c.end(), n, { 0, 0 });
404 node_type* d = m_c.data();
406 for (index_type i = -(d + before)->m_check; i != 0; i = -(d + before)->m_check)
412 for (index_type last = m_c.size();
id != last; before =
id++)
414 assert(before < limit());
415 (d + before)->m_check = -
id;
421 void allocate(index_type idx, index_type before = 0)
423 assert(1 < idx && idx < limit());
424 assert(m_c[idx].m_check <= 0);
425 assert(0 <= before && before < limit());
429 for (index_type i = -(d + before)->m_check; i != idx; i = -(d + before)->m_check)
435 (d + before)->m_check = (d + idx)->m_check;
436 (d + idx)->m_check = 0;
444 assert(!labels.empty());
445 assert(std::is_sorted(labels.begin(), labels.end()));
446 assert(before < limit());
448 if (limit() <= base + labels.back()) reserve(base + labels.back() + 1 - m_c.size());
451 for (std::uint16_t label : labels)
453 index_type idx = base + label;
454 assert(idx < limit());
455 if (1 <= (d + idx)->m_check)
continue;
457 for (index_type i = before; i != idx; i = -(d + before)->m_check)
463 (d + before)->m_check = (d + idx)->m_check;
464 (d + idx)->m_check = 0;
472 assert(1 <= parent && parent < limit());
473 assert(1 <= from && from < limit());
477 index_type last = std::min(from + null_value, limit() - 1);
478 for (index_type i = from; i <= last; ++i)
481 assert(0 <= i - from);
482 if ((m_c.data() + i)->m_check == parent) children.push_back(i - from);
486 std::set_union(labels.begin(), labels.end(), children.begin(), children.end(), std::back_inserter(all));
488 index_type before = 0;
489 index_type to =
locate(all, before);
493 for (std::uint16_t label : children)
495 index_type idx = from + label;
496 assert(idx < limit());
498 assert(to + label < limit());
499 (d + to + label)->m_base = (d + idx)->m_base;
500 (d + to + label)->m_check = (d + idx)->m_check;
503 index_type base = (d + idx)->m_base;
506 index_type last = std::min(base + null_value, limit() - 1);
507 for (index_type i = base; i <= last; ++i)
510 if ((d + i)->m_check == idx) (d + i)->m_check = to + label;
515 assert(parent < limit());
516 (d + parent)->m_base = to;
517 free(from, children);
526 index_type
free(index_type idx, index_type before = 0)
528 assert(1 < idx && idx < limit());
529 assert(0 <= before && before < limit());
533 for (index_type i = -(d + before)->m_check; i != 0 && i < idx; i = -(d + before)->m_check)
539 (d + idx)->m_base = 0;
540 (d + idx)->m_check = (d + before)->m_check;
541 (d + before)->m_check = -idx;
550 assert(1 <= base && base < limit());
551 assert(1 <= base && base + labels.back() < limit());
553 assert(!labels.empty());
554 assert(std::is_sorted(labels.begin(), labels.end()));
556 assert(0 <= before && before < limit());
558 for (
auto label : labels)
560 assert(base + label < limit());
561 before =
free(base + label, before);
574 assert(!labels.empty());
575 assert(std::is_sorted(labels.begin(), labels.end()));
580 std::uint16_t offset = labels.front();
583 index_type idx = -d->m_check;
586 for (; 0 < idx && idx <= offset; idx = -(d + idx)->m_check) before = idx;
591 for(; 0 != idx && !
is_free(idx - offset, labels); idx = -(d + idx)->m_check) before = idx;
592 if (idx != 0) base = idx - offset;
596 if (base == 0) base = std::max(limit() - offset, 1);
599 assert(0 <= before && before < base + labels.front());
609 assert(!labels.empty());
612 for (std::uint16_t label : labels)
614 if (label == 0 && base == 1)
return false;
616 index_type idx = base + label;
617 if (limit() <= idx)
break;
619 if (1 <= (d + idx)->m_check)
return false;
629 assert(1 <= parent && parent < limit());
631 assert(!labels.empty());
634 for (std::uint16_t label : labels)
636 if (label == 0 && base == 1)
return false;
638 index_type idx = base + label;
639 if (limit() <= idx)
break;
641 index_type check = (d + idx)->m_check;
642 if (1 <= check && parent != check)
return false;
653 assert(1 <= idx && idx < limit());
656 index_type base = (d + idx)->m_base;
658 return (base <= 0 && idx != 1)
659 || (1 <= base && base + null_value < limit() && (d + base + null_value)->m_check == idx);
666 assert(1 <= parent && parent < limit());
670 index_type base = (d + parent)->m_base;
671 assert(base < limit());
675 index_type last = std::min(base + null_value, limit());
676 assert(1 <= last && last <= limit());
678 for (index_type idx = base; 1 <= idx && idx < last; ++idx)
680 assert((d + idx)->m_check < limit());
681 if ((d + idx)->m_check == parent)
return true;
692 assert(1 <= parent && parent < limit());
696 index_type base = (d + parent)->m_base;
697 assert(base < limit());
701 index_type idx = base + null_value;
702 return idx < limit() && (d + idx)->m_check == parent;
712 assert(0 <= idx && idx < limit());
713 if (idx <= 1)
return false;
717 index_type parent = (d + idx)->m_check;
718 assert(1 <= parent && parent < limit());
720 index_type base = (d + parent)->m_base;
721 assert(1 <= base && base < limit());
723 index_type last = std::min(base + null_value, limit());
724 assert(1 <= last && last <= limit());
726 for (index_type i = base; i < last; ++i)
728 assert((d + i)->m_check < limit());
729 index_type check = (d + i)->m_check;
730 if (check == parent && i != idx)
return true;
740 assert(1 <= parent && parent < limit());
741 assert(0 <= idx && idx < limit());
743 if (idx <= 1)
return false;
747 index_type base = (d + parent)->m_base;
748 assert(1 <= base && base < limit());
750 index_type last = std::min(base + null_value, limit());
751 assert(1 <= last && last <= limit());
753 for (index_type i = base; i < last; ++i)
755 assert((d + i)->m_check < limit());
756 index_type check = (d + i)->m_check;
757 if (check == parent && i != idx)
return true;
766 index_type
at(index_type parent, std::uint16_t label)
const
768 assert(1 <= parent && parent < limit());
769 assert(
static_cast<index_type
>(label) <= null_value);
773 index_type base = (d + parent)->m_base;
775 index_type check = 0;
780 if (idx < limit()) check = (d + idx)->m_check;
781 if (check != parent) idx = 0;
787 index_type add(index_type parent, std::uint16_t label)
789 return add(parent, label_vector(1, label));
798 assert(parent < limit());
799 assert(!labels.empty());
800 assert(std::is_sorted(labels.begin(), labels.end()));
802 index_type before = 0;
803 index_type base = (m_c.data() + parent)->m_base;
804 assert(base < limit());
808 base =
locate(labels, before);
811 else if (
is_free(parent, base, labels))
allocate(base, labels, before);
812 else base =
relocate(parent, base, labels);
814 assert(base +
static_cast<index_type
>(labels.back()) < limit());
817 (d + parent)->m_base = base;
818 for (std::uint16_t label : labels) (d + base + label)->m_check = parent;
820 assert(1 <= base && base < limit());
829 template <
typename Allocator1>
830 inline std::ostream&
operator<<(std::ostream& os, trie_heap<Allocator1>
const& heap)
832 std::uint64_t n =
static_cast<std::uint64_t
>(heap.m_c.size()) *
sizeof(trie_node);
833 auto length = serialize(n);
834 auto it1 = length.begin();
835 auto it2 = length.end();
836 while (it1 != it2) os.put(*it1++);
840 while (it3 != it4) os.put(*it3++);
845 template <
typename Allocator1>
846 inline std::istream&
operator>>(std::istream& is, trie_heap<Allocator1>& heap)
850 auto it1 = std::istreambuf_iterator<char>(is);
851 auto it2 = std::istreambuf_iterator<char>();
854 it1 = deserialize(it1, it2, n);
856 for (std::uint64_t i = 0; i < n && it1 != it2; ++i)
858 std::int32_t base, check;
859 it1 = deserialize(it1, it2, base);
860 it1 = deserialize(it1, it2, check);
861 heap.m_c.push_back({ base, check });