libwordring
trie_heap.hpp
1 #pragma once
2 
#include <wordring/serialize/serialize.hpp>
#include <wordring/serialize/serialize_iterator.hpp>
#include <wordring/static_vector/static_vector.hpp>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <istream>
#include <iterator>
#include <memory>
#include <ostream>
#include <stdexcept>
#include <type_traits>
#include <vector>
16 
17 namespace wordring::detail
18 {
19  // ------------------------------------------------------------------------
20  // trie_node
21  // ------------------------------------------------------------------------
22 
/*! @brief One element of the double array: a BASE / CHECK pair.
 *
 * Both members default to 0, so a default-initialised node never holds
 * indeterminate values. The type is still an aggregate, which keeps
 * `trie_node{ base, check }` and `{ base, check }` initialisation working.
 */
struct trie_node
{
    using index_type = std::int32_t;

    index_type m_base = 0;
    index_type m_check = 0;
};
30 
31  inline bool operator==(trie_node const& lhs, trie_node const& rhs)
32  {
33  return lhs.m_base == rhs.m_base && lhs.m_check == rhs.m_check;
34  }
35 
36  // ------------------------------------------------------------------------
37  // trie_value_proxy
38  // ------------------------------------------------------------------------
39 
41  {
42  using index_type = typename trie_node::index_type;
43  using node_type = trie_node;
44 
45  node_type* m_node;
46 
48  : m_node(nullptr)
49  {
50  }
51 
53  : m_node(node)
54  {
55  }
56 
57  void operator=(index_type val)
58  {
59  if (val < 0) throw std::invalid_argument("");
60  m_node->m_base = -val;
61  }
62 
63  operator index_type() const
64  {
65  assert(m_node->m_base <= 0);
66  return -m_node->m_base;
67  }
68  };
69 
70  // ------------------------------------------------------------------------
71  // trie_heap_serialize_iterator
72  // ------------------------------------------------------------------------
73 
77  template <typename Container>
79  {
80  template <typename Allocator1>
81  friend class trie_heap;
82 
83  template <typename Container1>
85 
86  template <typename Container1>
88 
89  public:
90  using difference_type = std::ptrdiff_t;
91  using value_type = std::uint32_t;
92  using pointer = value_type*;
93  using reference = value_type&;
94  using iterator_category = std::input_iterator_tag;
95 
96  protected:
97  using node_type = trie_node;
98  using container = Container const;
99 
100  public:
102  : m_c(nullptr)
103  , m_index(0)
104  {
105  }
106 
107  protected:
108  trie_heap_serialize_iterator(container const& c, std::uint32_t n)
109  : m_c(std::addressof(c))
110  , m_index(n * 2)
111  {
112  }
113 
114  public:
115  value_type operator*() const
116  {
117  node_type const* p = m_c->data() + (m_index / 2);
118  return (m_index % 2) ? p->m_check : p->m_base;
119  }
120 
121  trie_heap_serialize_iterator& operator++()
122  {
123  ++m_index;
124  return *this;
125  }
126 
127  trie_heap_serialize_iterator operator++(int)
128  {
129  auto result = *this;
130  ++m_index;
131  return result;
132  }
133 
134  protected:
135  container* m_c;
136  std::uint32_t m_index;
137  };
138 
139  template <typename Container1>
140  inline bool operator==(trie_heap_serialize_iterator<Container1> const& lhs, trie_heap_serialize_iterator<Container1> const& rhs)
141  {
142  assert(lhs.m_c == rhs.m_c);
143  return lhs.m_index == rhs.m_index;
144  }
145 
146  template <typename Container1>
147  inline bool operator!=(trie_heap_serialize_iterator<Container1> const& lhs, trie_heap_serialize_iterator<Container1> const& rhs)
148  {
149  assert(lhs.m_c == rhs.m_c);
150  return lhs.m_index != rhs.m_index;
151  }
152 
153  // ------------------------------------------------------------------------
154  // trie_heap
155  // ------------------------------------------------------------------------
156 
199  template <typename Allocator>
200  class trie_heap
201  {
202  template <typename Allocator1>
203  friend std::ostream& operator<<(std::ostream&, trie_heap<Allocator1> const&);
204 
205  template <typename Allocator1>
206  friend std::istream& operator>>(std::istream&, trie_heap<Allocator1>&);
207 
208  protected:
209  using index_type = typename trie_node::index_type;
210  using node_type = trie_node;
211  using container = std::vector<trie_node, Allocator>;
213 
214  static constexpr std::uint16_t null_value = 256u;
215 
216  public:
217  using label_type = std::uint8_t;
218  using allocator_type = Allocator;
220 
221  public:
226  allocator_type get_allocator() const { return m_c.get_allocator(); }
227 
259  template <typename InputIterator, typename std::enable_if_t<std::is_integral_v<typename std::iterator_traits<InputIterator>::value_type>, std::nullptr_t> = nullptr>
260  void assign(InputIterator first, InputIterator last)
261  {
262  using iterator_category = typename std::iterator_traits<InputIterator>::iterator_category;
263  using value_type = typename std::iterator_traits<InputIterator>::value_type;
264  using unsigned_type = std::make_unsigned_t<value_type>;
265 
266  std::uint32_t constexpr n = sizeof(index_type) / sizeof(value_type);
267 
268  static_assert(sizeof(value_type) <= sizeof(index_type));
269 
270  m_c.clear();
271  if constexpr (std::is_same_v<iterator_category, std::random_access_iterator_tag>)
272  m_c.reserve(std::distance(first, last) / 2);
273 
274  while (first != last)
275  {
276  std::make_unsigned_t<index_type> base = 0;
277  std::make_unsigned_t<index_type> check = 0;
278 
279  if constexpr (n == 1)
280  {
281  base = *first++;
282  check = *first++;
283  }
284  else
285  {
286  for (std::uint32_t j = 0; j < n && first != last; ++j) base = (base << 8) + static_cast<unsigned_type>(*first++);
287  for (std::uint32_t j = 0; j < n && first != last; ++j) check = (check << 8) + static_cast<unsigned_type>(*first++);
288  }
289 
290  m_c.push_back(node_type{ static_cast<index_type>(base), static_cast<index_type>(check) });
291  }
292  }
293 
343  {
344  return serialize_iterator(m_c, 0);
345  }
346 
354  {
355  return serialize_iterator(m_c, m_c.size());
356  }
357 
358  // 変更 ---------------------------------------------------------------
359 
364  void clear() noexcept
365  {
366  m_c.clear();
367  m_c.insert(m_c.begin(), 2, trie_node{ 0, 0 });
368  }
369 
370  void swap(trie_heap& other) { m_c.swap(other.m_c); }
371 
372  protected:
373  trie_heap()
374  : m_c(2, { 0, 0 })
375  {
376  }
377 
378  explicit trie_heap(allocator_type const& alloc)
379  : m_c(2, { 0, 0 }, alloc)
380  {
381  }
382 
390  trie_heap(std::initializer_list<trie_node> il, allocator_type const& alloc = allocator_type())
391  : m_c(il, alloc)
392  {
393  }
394 
395  index_type limit() const { return m_c.size(); }
396 
397  void reserve(std::size_t n, index_type before = 0)
398  {
399  assert(0 <= before && before < limit());
400 
401  index_type id = m_c.size(); // reserveする先頭の番号
402  m_c.insert(m_c.end(), n, { 0, 0 });
403 
404  node_type* d = m_c.data();
405  // 値が0のINDEXを探す
406  for (index_type i = -(d + before)->m_check; i != 0; i = -(d + before)->m_check)
407  {
408  assert(i < limit());
409  before = i;
410  }
411  // CHECKを更新する
412  for (index_type last = m_c.size(); id != last; before = id++)
413  {
414  assert(before < limit());
415  (d + before)->m_check = -id;
416  }
417  }
418 
421  void allocate(index_type idx, index_type before = 0)
422  {
423  assert(1 < idx && idx < limit());
424  assert(m_c[idx].m_check <= 0);
425  assert(0 <= before && before < limit());
426 
427  node_type* d = m_c.data();
428  // CHECKがidxと一致するINDEXを探す
429  for (index_type i = -(d + before)->m_check; i != idx; i = -(d + before)->m_check)
430  {
431  assert(i < limit());
432  before = i;
433  }
434  // CHECKを更新する
435  (d + before)->m_check = (d + idx)->m_check;
436  (d + idx)->m_check = 0;
437  }
438 
441  void allocate(index_type base, label_vector const& labels, index_type before = 0)
442  {
443  assert(1 <= base);
444  assert(!labels.empty());
445  assert(std::is_sorted(labels.begin(), labels.end()));
446  assert(before < limit());
447 
448  if (limit() <= base + labels.back()) reserve(base + labels.back() + 1 - m_c.size());
449 
450  trie_node* d = m_c.data();
451  for (std::uint16_t label : labels)
452  {
453  index_type idx = base + label;
454  assert(idx < limit());
455  if (1 <= (d + idx)->m_check) continue; // 登録済み
456 
457  for (index_type i = before; i != idx; i = -(d + before)->m_check)
458  {
459  assert(i < limit());
460  before = i;
461  }
462 
463  (d + before)->m_check = (d + idx)->m_check;
464  (d + idx)->m_check = 0;
465  }
466  }
467 
470  index_type relocate(index_type parent, index_type from, label_vector const& labels)
471  {
472  assert(1 <= parent && parent < limit());
473  assert(1 <= from && from < limit());
474 
475  label_vector children;
476 
477  index_type last = std::min(from + null_value, limit() - 1);
478  for (index_type i = from; i <= last; ++i)
479  {
480  assert(i < limit());
481  assert(0 <= i - from);
482  if ((m_c.data() + i)->m_check == parent) children.push_back(i - from);
483  }
484 
485  label_vector all;
486  std::set_union(labels.begin(), labels.end(), children.begin(), children.end(), std::back_inserter(all));
487 
488  index_type before = 0;
489  index_type to = locate(all, before);
490  allocate(to, all, before);
491 
492  node_type* d = m_c.data();
493  for (std::uint16_t label : children)
494  {
495  index_type idx = from + label; // 子の旧INDEX
496  assert(idx < limit());
497 
498  assert(to + label < limit());
499  (d + to + label)->m_base = (d + idx)->m_base; // 子のBASEを置換
500  (d + to + label)->m_check = (d + idx)->m_check; // 子のCHECKを置換
501 
502  // 孫のCHECKを置き換え
503  index_type base = (d + idx)->m_base;
504  if (1 <= base)
505  {
506  index_type last = std::min(base + null_value, limit() - 1);
507  for (index_type i = base; i <= last; ++i)
508  {
509  assert(i < limit());
510  if ((d + i)->m_check == idx) (d + i)->m_check = to + label;
511  }
512  }
513  }
514 
515  assert(parent < limit());
516  (d + parent)->m_base = to;
517  free(from, children);
518 
519  return to;
520  }
521 
526  index_type free(index_type idx, index_type before = 0)
527  {
528  assert(1 < idx && idx < limit());
529  assert(0 <= before && before < limit());
530 
531  node_type* d = m_c.data();
532 
533  for (index_type i = -(d + before)->m_check; i != 0 && i < idx; i = -(d + before)->m_check)
534  {
535  assert(i < limit());
536  before = i;
537  }
538 
539  (d + idx)->m_base = 0;
540  (d + idx)->m_check = (d + before)->m_check;
541  (d + before)->m_check = -idx;
542 
543  return before;
544  }
545 
548  void free(index_type base, label_vector const& labels, index_type before = 0)
549  {
550  assert(1 <= base && base < limit());
551  assert(1 <= base && base + labels.back() < limit());
552 
553  assert(!labels.empty());
554  assert(std::is_sorted(labels.begin(), labels.end()));
555 
556  assert(0 <= before && before < limit());
557 
558  for (auto label : labels)
559  {
560  assert(base + label < limit());
561  before = free(base + label, before);
562  }
563  }
564 
572  index_type locate(label_vector const& labels, index_type& before) const
573  {
574  assert(!labels.empty());
575  assert(std::is_sorted(labels.begin(), labels.end()));
576 
577  index_type base = 0;
578  before = 0;
579 
580  std::uint16_t offset = labels.front();
581 
582  trie_node const* d = m_c.data();
583  index_type idx = -d->m_check;
584 
585  // BASEが負にならないよう、検索開始位置を設定する。
586  for (; 0 < idx && idx <= offset; idx = -(d + idx)->m_check) before = idx;
587 
588  // BASEを正に調整可能な位置に一つでもラべルを配置可能な空きノードがある場合、それに基づき計算する。
589  if (offset < idx)
590  {
591  for(; 0 != idx && !is_free(idx - offset, labels); idx = -(d + idx)->m_check) before = idx;
592  if (idx != 0) base = idx - offset;
593  }
594 
595  // そのような空きノードが無い場合、すべてのラベルが新規にreserveされるノードに配置される。
596  if (base == 0) base = std::max(limit() - offset, 1);
597 
598  assert(1 <= base);
599  assert(0 <= before && before < base + labels.front());
600 
601  return base;
602  }
603 
606  bool is_free(index_type base, label_vector const& labels) const
607  {
608  assert(1 <= base);
609  assert(!labels.empty());
610 
611  node_type const* d = m_c.data();
612  for (std::uint16_t label : labels)
613  {
614  if (label == 0 && base == 1) return false; // index 1のcheckは常に0のため
615 
616  index_type idx = base + label;
617  if (limit() <= idx) break;
618 
619  if (1 <= (d + idx)->m_check) return false;
620  }
621 
622  return true;
623  }
624 
627  bool is_free(index_type parent, index_type base, label_vector const& labels) const
628  {
629  assert(1 <= parent && parent < limit());
630  assert(1 <= base);
631  assert(!labels.empty());
632 
633  node_type const* d = m_c.data();
634  for (std::uint16_t label : labels)
635  {
636  if (label == 0 && base == 1) return false; // index 1のcheckは常に0のため
637 
638  index_type idx = base + label;
639  if (limit() <= idx) break;
640 
641  index_type check = (d + idx)->m_check;
642  if (1 <= check && parent != check) return false;
643  }
644 
645  return true;
646  }
647 
651  bool is_tail(index_type idx) const
652  {
653  assert(1 <= idx && idx < limit());
654 
655  node_type const* d = m_c.data();
656  index_type base = (d + idx)->m_base; // 子のBASEインデックス。
657 
658  return (base <= 0 && idx != 1)
659  || (1 <= base && base + null_value < limit() && (d + base + null_value)->m_check == idx);
660  }
661 
664  bool has_child(index_type parent) const
665  {
666  assert(1 <= parent && parent < limit());
667 
668  node_type const* d = m_c.data();
669 
670  index_type base = (d + parent)->m_base;
671  assert(base < limit());
672 
673  if (1 <= base)
674  {
675  index_type last = std::min(base + null_value, limit());
676  assert(1 <= last && last <= limit());
677 
678  for (index_type idx = base; 1 <= idx && idx < last; ++idx)
679  {
680  assert((d + idx)->m_check < limit());
681  if ((d + idx)->m_check == parent) return true;
682  }
683  }
684 
685  return false;
686  }
687 
690  bool has_null(index_type parent) const
691  {
692  assert(1 <= parent && parent < limit());
693 
694  node_type const* d = m_c.data();
695 
696  index_type base = (d + parent)->m_base;
697  assert(base < limit());
698 
699  if (1 <= base)
700  {
701  index_type idx = base + null_value;
702  return idx < limit() && (d + idx)->m_check == parent;
703  }
704 
705  return false;
706  }
707 
710  bool has_sibling(index_type idx) const
711  {
712  assert(0 <= idx && idx < limit());
713  if (idx <= 1) return false;
714 
715  node_type const* d = m_c.data();
716 
717  index_type parent = (d + idx)->m_check;
718  assert(1 <= parent && parent < limit());
719 
720  index_type base = (d + parent)->m_base;
721  assert(1 <= base && base < limit());
722 
723  index_type last = std::min(base + null_value, limit());
724  assert(1 <= last && last <= limit());
725 
726  for (index_type i = base; i < last; ++i)
727  {
728  assert((d + i)->m_check < limit());
729  index_type check = (d + i)->m_check;
730  if (check == parent && i != idx) return true;
731  }
732 
733  return false;
734  }
735 
738  bool has_sibling(index_type parent, index_type idx) const
739  {
740  assert(1 <= parent && parent < limit());
741  assert(0 <= idx && idx < limit());
742 
743  if (idx <= 1) return false;
744 
745  node_type const* d = m_c.data();
746 
747  index_type base = (d + parent)->m_base;
748  assert(1 <= base && base < limit());
749 
750  index_type last = std::min(base + null_value, limit());
751  assert(1 <= last && last <= limit());
752 
753  for (index_type i = base; i < last; ++i)
754  {
755  assert((d + i)->m_check < limit());
756  index_type check = (d + i)->m_check;
757  if (check == parent && i != idx) return true;
758  }
759 
760  return false;
761  }
762 
766  index_type at(index_type parent, std::uint16_t label) const
767  {
768  assert(1 <= parent && parent < limit());
769  assert(static_cast<index_type>(label) <= null_value);
770 
771  node_type const* d = m_c.data();
772 
773  index_type base = (d + parent)->m_base;
774  index_type idx = 0;
775  index_type check = 0;
776 
777  if (1 <= base)
778  {
779  idx = base + label;
780  if (idx < limit()) check = (d + idx)->m_check;
781  if (check != parent) idx = 0;
782  }
783 
784  return idx ;
785  }
786 
787  index_type add(index_type parent, std::uint16_t label)
788  {
789  return add(parent, label_vector(1, label));
790  }
791 
796  index_type add(index_type parent, label_vector const& labels)
797  {
798  assert(parent < limit());
799  assert(!labels.empty());
800  assert(std::is_sorted(labels.begin(), labels.end()));
801 
802  index_type before = 0;
803  index_type base = (m_c.data() + parent)->m_base; // 遷移先配置の起点(遷移先が定義されていない場合0)
804  assert(base < limit());
805 
806  if (base <= 0) // 子が無い。
807  {
808  base = locate(labels, before);
809  allocate(base, labels, before);
810  }
811  else if (is_free(parent, base, labels)) allocate(base, labels, before);
812  else base = relocate(parent, base, labels);
813 
814  assert(base + static_cast<index_type>(labels.back()) < limit());
815  node_type* d = m_c.data();
816 
817  (d + parent)->m_base = base;
818  for (std::uint16_t label : labels) (d + base + label)->m_check = parent;
819 
820  assert(1 <= base && base < limit());
821 
822  return base;
823  }
824 
825  protected:
826  container m_c;
827  };
828 
829  template <typename Allocator1>
830  inline std::ostream& operator<<(std::ostream& os, trie_heap<Allocator1> const& heap)
831  {
832  std::uint64_t n = static_cast<std::uint64_t>(heap.m_c.size()) * sizeof(trie_node);
833  auto length = serialize(n);
834  auto it1 = length.begin();
835  auto it2 = length.end();
836  while (it1 != it2) os.put(*it1++);
837 
838  auto it3 = serialize_iterator(heap.ibegin());
839  auto it4 = serialize_iterator(heap.iend());
840  while (it3 != it4) os.put(*it3++);
841 
842  return os;
843  }
844 
845  template <typename Allocator1>
846  inline std::istream& operator>>(std::istream& is, trie_heap<Allocator1>& heap)
847  {
848  heap.m_c.clear();
849 
850  auto it1 = std::istreambuf_iterator<char>(is);
851  auto it2 = std::istreambuf_iterator<char>();
852 
853  std::uint64_t n;
854  it1 = deserialize(it1, it2, n);
855 
856  for (std::uint64_t i = 0; i < n && it1 != it2; ++i)
857  {
858  std::int32_t base, check;
859  it1 = deserialize(it1, it2, base);
860  it1 = deserialize(it1, it2, check);
861  heap.m_c.push_back({ base, check });
862  }
863 
864  return is;
865  }
866 }
wordring::serialize_iterator
任意型の整数列に対するイテレータをバイトを返すイテレータへ変換する
Definition: serialize_iterator.hpp:50
wordring::detail::trie_heap::has_null
bool has_null(index_type parent) const
Definition: trie_heap.hpp:690
wordring::detail::trie_heap::free
index_type free(index_type idx, index_type before=0)
Definition: trie_heap.hpp:526
wordring::detail::trie_heap::is_tail
bool is_tail(index_type idx) const
Definition: trie_heap.hpp:651
wordring::detail::trie_heap::has_child
bool has_child(index_type parent) const
Definition: trie_heap.hpp:664
wordring::detail::trie_heap::is_free
bool is_free(index_type base, label_vector const &labels) const
Definition: trie_heap.hpp:606
wordring::static_vector< std::uint16_t, 257 >
wordring::detail::operator<<
std::ostream & operator<<(std::ostream &os, stable_trie_base< Allocator1 > const &trie)
ストリームへ出力する
Definition: stable_trie_base.hpp:752
wordring::detail::trie_heap::assign
void assign(InputIterator first, InputIterator last)
直列化データから割り当てる
Definition: trie_heap.hpp:260
wordring::detail::trie_heap::clear
void clear() noexcept
すべての要素を削除する
Definition: trie_heap.hpp:364
wordring::detail::trie_heap::allocate
void allocate(index_type idx, index_type before=0)
Definition: trie_heap.hpp:421
wordring::detail::trie_heap::get_allocator
allocator_type get_allocator() const
コンテナに関連付けられているアロケータを返す
Definition: trie_heap.hpp:226
wordring::detail::trie_heap::is_free
bool is_free(index_type parent, index_type base, label_vector const &labels) const
Definition: trie_heap.hpp:627
wordring::detail::trie_heap::relocate
index_type relocate(index_type parent, index_type from, label_vector const &labels)
Definition: trie_heap.hpp:470
wordring::detail::trie_node
Definition: trie_heap.hpp:23
wordring::detail::trie_heap::allocate
void allocate(index_type base, label_vector const &labels, index_type before=0)
Definition: trie_heap.hpp:441
wordring::detail::trie_heap::trie_heap
trie_heap(std::initializer_list< trie_node > il, allocator_type const &alloc=allocator_type())
初期化子リストから構築する
Definition: trie_heap.hpp:390
wordring::detail::trie_heap::ibegin
serialize_iterator ibegin() const
直列化用のイテレータを返す
Definition: trie_heap.hpp:342
wordring::detail
wordring::detail::trie_heap
ダブル・アレイによるTrie実装のメモリー管理を行う
Definition: trie_heap.hpp:200
wordring::detail::trie_heap::at
index_type at(index_type parent, std::uint16_t label) const
Definition: trie_heap.hpp:766
wordring::detail::trie_heap::add
index_type add(index_type parent, label_vector const &labels)
Definition: trie_heap.hpp:796
wordring::detail::trie_heap::has_sibling
bool has_sibling(index_type parent, index_type idx) const
Definition: trie_heap.hpp:738
wordring::detail::trie_heap::has_sibling
bool has_sibling(index_type idx) const
Definition: trie_heap.hpp:710
wordring::detail::trie_value_proxy
Definition: trie_heap.hpp:40
wordring::detail::trie_heap::iend
serialize_iterator iend() const
直列化用のイテレータを返す
Definition: trie_heap.hpp:353
wordring::detail::trie_heap::locate
index_type locate(label_vector const &labels, index_type &before) const
Definition: trie_heap.hpp:572
wordring::detail::operator>>
std::istream & operator>>(std::istream &is, stable_trie_base< Allocator1 > &trie)
ストリームから入力する
Definition: stable_trie_base.hpp:763
wordring::detail::trie_heap::free
void free(index_type base, label_vector const &labels, index_type before=0)
Definition: trie_heap.hpp:548
wordring::detail::trie_heap_serialize_iterator
Definition: trie_heap.hpp:78