#!/usr/bin/python

# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
#
# Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved.
#
# The contents of this file are subject to the terms of either the GNU Lesser
# General Public License Version 2.1 only ("LGPL") or the Common Development and
# Distribution License ("CDDL")(collectively, the "License"). You may not use this
# file except in compliance with the License. You can obtain a copy of the CDDL at
# http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at
# http://www.opensource.org/licenses/lgpl-license.php. See the License for the
# specific language governing permissions and limitations under the License. When
# distributing the software, include this License Header Notice in each file and
# include the full text of the License in the License file as well as the
# following notice:
#
# NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE
# (CDDL)
# For Covered Software in this distribution, this License shall be governed by the
# laws of the State of California (excluding conflict-of-law provisions).
# Any litigation relating to this License shall be subject to the jurisdiction of
# the Federal Courts of the Northern District of California and the state courts
# of the State of California, with venue lying in Santa Clara County, California.
#
# Contributor(s):
#
# If you wish your version of this file to be governed by only the CDDL or only
# the LGPL Version 2.1, indicate your decision by adding "[Contributor]" elects to
# include this software in this distribution under the [CDDL or LGPL Version 2.1]
# license." If you don't indicate a single choice of license, a recipient has the
# option to distribute your version of this file under either the CDDL or the LGPL
# Version 2.1, or to extend the choice of license to its licensees as provided
# above. However, if you add LGPL Version 2.1 code and therefore, elected the LGPL
# Version 2 license, then the option applies only if the new code is made subject
# to such option by the copyright holder.

"""Dictionary-based trie and a double-array trie (DATrie) built from it.

``Trie`` is a plain nested-dict trie.  ``DATrie`` flattens such a trie into
``base`` / ``check`` / ``value`` tables that can be saved to and loaded from a
binary file.  Both expose ``root`` and ``walk(state, ch)`` so that
``match_longest`` works with either of them.
"""

__all__ = ['Trie', 'DATrie', 'match_longest', 'get_ambiguious_length']

from collections import deque
from math import log
import struct


class Trie(object):
    """A trie whose nodes keep their outgoing transitions in a dict."""

    class TrieNode(object):
        """One trie state: ``val`` (0 means "no word ends here") and ``trans``."""

        def __init__(self):
            self.val = 0
            self.trans = {}

    def __init__(self):
        self.root = Trie.TrieNode()

    def add(self, word, value=1):
        """Insert ``word`` and store ``value`` on its final node.

        A falsy ``value`` makes the word indistinguishable from a mere prefix,
        because lookups treat a falsy node value as "no word ends here".
        """
        curr_node = self.root
        for ch in word:
            next_node = curr_node.trans.get(ch)
            if next_node is None:
                next_node = curr_node.trans[ch] = Trie.TrieNode()
            curr_node = next_node

        curr_node.val = value

    def walk(self, trienode, ch):
        """Follow the ``ch`` transition out of ``trienode``.

        Returns ``(next_node, next_node.val)``, or ``(None, 0)`` when there is
        no such transition.
        """
        if ch in trienode.trans:
            trienode = trienode.trans[ch]
            return trienode, trienode.val
        return None, 0


class FlexibleList(list):
    """A list that grows with zeros when indexed at or past its end.

    Note that *reading* an out-of-range index also extends the list.
    """

    def __check_size(self, index):
        # Slices (which Python 3 routes through __getitem__/__setitem__) are
        # passed straight to the list implementation.
        if isinstance(index, slice):
            return
        if index >= len(self):
            self.extend([0] * (index - len(self) + 1))

    def __getitem__(self, index):
        self.__check_size(index)
        return list.__getitem__(self, index)

    def __setitem__(self, index, value):
        self.__check_size(index)
        return list.__setitem__(self, index, value)


def character_based_encoder(ch, range=('a', 'z')):
    """Map ``ch`` to a positive integer code relative to ``range[0]``.

    ``range[0]`` maps to 1, the next character to 2, and so on.  Characters
    that sort before ``range[0]`` all map to ``ord(range[1]) + 1``.
    """
    ret = ord(ch) - ord(range[0]) + 1
    if ret <= 0:
        ret = ord(range[1]) + 1
    return ret


def _padded(seq, size):
    """Return ``seq`` as a new list, zero-padded on the right to ``size``."""
    items = list(seq)
    if len(items) < size:
        items.extend([0] * (size - len(items)))
    return items


def _fits_int16(seq):
    """Return True when every item of ``seq`` fits a signed 16-bit integer."""
    return all(-0x8000 <= x <= 0x7FFF for x in seq)


class DATrie(object):
    """Double-array trie.

    State ``s`` moves on a character with code ``c`` to ``t = abs(base[s]) + c``
    and the move is valid when ``check[t] == s``.  A negative ``base[t]`` marks
    a state whose trie node carried a truthy value.  State 0 is the root.
    """

    def __init__(self, chr_encoder=character_based_encoder):
        self.root = 0
        self.chr_encoder = chr_encoder
        self.clear()

    def clear(self):
        """Drop all tables, leaving an empty trie."""
        self.base = FlexibleList()
        self.check = FlexibleList()
        self.value = FlexibleList()

    def walk(self, s, ch):
        """Follow the ``ch`` transition out of state ``s``.

        Returns ``(next_state, value)``, or ``(0, 0)`` when there is no such
        transition.  Without a value table the value is -1 for states marked
        as word ends (negative base) and 0 otherwise.
        """
        c = self.chr_encoder(ch)
        t = abs(self.base[s]) + c

        if t < len(self.check) and self.check[t] == s and self.base[t]:
            if self.value:
                v = self.value[t]
            else:
                v = -1 if self.base[t] < 0 else 0
            return t, v
        return 0, 0

    def find_base(self, s, children, i=1):
        """Find a base offset for state ``s`` so all ``children`` land on free slots.

        ``children`` is an iterable of transition characters and ``i`` is the
        first offset to try.  For the root (``s == 0``) or a state without
        children, ``s`` itself is returned.
        """
        if s == 0 or not children:
            return s

        # The encoder result does not depend on the offset being probed, so
        # encode every child once instead of once per probe.
        codes = [self.chr_encoder(ch) for ch in children]

        i = max(i, 1)
        loop_times = 0
        while True:
            for c in codes:
                k = i + c
                if self.base[k] or self.check[k] or k == s:
                    # Slot taken: skip ahead by a step that grows
                    # logarithmically with the number of collisions so far.
                    loop_times += 1
                    i += int(log(loop_times, 2)) + 1
                    break
            else:
                return i

    def build(self, words, values=None):
        """Build the tables from ``words`` and, optionally, parallel ``values``.

        Without ``values`` every word is stored with -1 and no value table is
        produced.
        """
        assert (not values or (len(words) == len(values)))
        itval = iter(values) if values else None

        trie = Trie()
        for w in words:
            trie.add(w, next(itval) if itval else -1)

        self.construct_from_trie(trie, values is not None)

    def construct_from_trie(self, trie, with_value=True, progress_cb=None,
                            progress_cb_thr=100):
        """Fill the tables from ``trie`` by a breadth-first traversal.

        ``with_value`` controls whether node values are copied into
        ``self.value``.  ``progress_cb``, when given, is called with no
        arguments once every ``progress_cb_thr`` processed nodes.  The tables
        are not cleared first.
        """
        nodes = deque([(trie.root, 0)])
        find_from = 1
        loop_times = 0

        while nodes:
            trienode, s = nodes.popleft()
            find_from = b = self.find_base(s, trienode.trans, find_from)
            self.base[s] = -b if trienode.val else b
            if with_value:
                self.value[s] = trienode.val

            for ch in trienode.trans:
                t = b + self.chr_encoder(ch)
                # The root is state 0, which is also the "free slot" marker in
                # check[], so its children are tagged -1 while building to keep
                # find_base from handing their slots out again.
                self.check[t] = s if s else -1
                nodes.append((trienode.trans[ch], t))

            loop_times += 1
            if loop_times == progress_cb_thr:
                loop_times = 0
                if progress_cb:
                    progress_cb()

        # Replace the temporary -1 tags with the real parent state (0).  Visit
        # the root's children directly: an encoder need not be monotonic, so a
        # scan bounded by the code of the largest character could miss some.
        root_base = abs(self.base[0])
        for ch in trie.root.trans:
            t = root_base + self.chr_encoder(ch)
            if self.check[t] == -1:
                self.check[t] = 0

    def save(self, fname):
        """Write the tables to the binary file ``fname``.

        Layout, in native byte order with standard sizes: uint32 entry count,
        uint16 "cells are 32-bit" flag, uint16 "has value table" flag, then the
        base and check tables (int16 or int32 cells) and, if present, the
        value table (int32 cells).
        """
        has_value = 1 if self.value else 0

        # The tables grow independently, so pad copies to a common length
        # rather than assuming they already match.
        l = max(len(self.base), len(self.check))
        if has_value:
            l = max(l, len(self.value))
        base = _padded(self.base, l)
        check = _padded(self.check, l)

        # Pick the cell width from the data itself; the entry count alone does
        # not tell whether every cell fits a signed 16-bit integer.
        using_32bits = not (_fits_int16(base) and _fits_int16(check))
        fmt_str = '=%d%s' % (l, 'l' if using_32bits else 'h')

        # '=' selects standard sizes, so the header is always 4 + 2 + 2 bytes
        # and 'l' is always 4 bytes, which is what load() reads back.
        with open(fname, 'wb') as f:
            f.write(struct.pack('=LHH', l, int(using_32bits), has_value))
            f.write(struct.pack(fmt_str, *base))
            f.write(struct.pack(fmt_str, *check))
            if has_value:
                f.write(struct.pack('=%dl' % l, *_padded(self.value, l)))

    def load(self, fname):
        """Read tables written by ``save`` from ``fname``.

        The loaded ``base`` and ``check`` are tuples; ``value`` is a tuple, or
        an empty list when the file has no value table.  Raises
        ``struct.error`` if the file is truncated.
        """
        with open(fname, 'rb') as f:
            l, using_32bits, has_value = struct.unpack('=LHH', f.read(8))

            fmt_str = '=%d%s' % (l, 'l' if using_32bits else 'h')
            elm_size = 4 if using_32bits else 2

            self.base = struct.unpack(fmt_str, f.read(l * elm_size))
            self.check = struct.unpack(fmt_str, f.read(l * elm_size))
            if has_value:
                self.value = struct.unpack('=%dl' % l, f.read(l * 4))
            else:
                self.value = []


def match_longest(trie, word):
    """Find the longest prefix of ``word`` that is stored in ``trie``.

    ``trie`` may be a ``Trie`` or a ``DATrie``.  Returns ``(value, length)``
    of that prefix, or ``(0, 0)`` when no prefix matches.
    """
    l = ret_l = ret_v = 0
    curr_node = trie.root

    for ch in word:
        curr_node, val = trie.walk(curr_node, ch)
        if not curr_node:
            break

        l += 1
        if val:
            ret_l, ret_v = l, val

    return ret_v, ret_l


def get_ambiguious_length(trie, str, word_len):
    """Extend ``word_len`` over words that start inside the first word.

    For each offset ``i`` from 1 while ``i`` is below both the current
    ``word_len`` and ``len(str)``, the longest word starting at ``i`` may push
    ``word_len`` out to its own end.  Returns the offset at which the scan
    stops.
    """
    i = 1
    while i < word_len and i < len(str):
        _, l = match_longest(trie, str[i:])
        if word_len < i + l:
            word_len = i + l
        i += 1
    return i


def test():
    """Self-test against the project's pinyin syllable table."""
    import os
    import tempfile

    from pinyin_data import valid_syllables

    trie = Trie()
    for s in valid_syllables:
        trie.add(s, valid_syllables[s])

    for s in valid_syllables:
        v, l = match_longest(trie, s + 'b')
        assert (len(s) == l and valid_syllables[s] == v)

    datrie = DATrie()
    datrie.construct_from_trie(trie)

    fd, fname = tempfile.mkstemp(prefix='trie_test')
    os.close(fd)
    try:
        datrie.save(fname)
        datrie.load(fname)
    finally:
        os.remove(fname)

    for s in valid_syllables:
        v, l = match_longest(datrie, s + 'b')
        assert (len(s) == l and valid_syllables[s] == v)

    print('test executed successfully')


if __name__ == "__main__":
    test()