Home | History | Annotate | Download | only in python
      1 #!/usr/bin/python
      2 
      3 # DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
      4 # 
      5 # Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved.
      6 # 
      7 # The contents of this file are subject to the terms of either the GNU Lesser
      8 # General Public License Version 2.1 only ("LGPL") or the Common Development and
      9 # Distribution License ("CDDL")(collectively, the "License"). You may not use this
     10 # file except in compliance with the License. You can obtain a copy of the CDDL at
     11 # http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at
     12 # http://www.opensource.org/licenses/lgpl-license.php. See the License for the 
     13 # specific language governing permissions and limitations under the License. When
     14 # distributing the software, include this License Header Notice in each file and
     15 # include the full text of the License in the License file as well as the
     16 # following notice:
     17 # 
     18 # NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE
     19 # (CDDL)
     20 # For Covered Software in this distribution, this License shall be governed by the
     21 # laws of the State of California (excluding conflict-of-law provisions).
     22 # Any litigation relating to this License shall be subject to the jurisdiction of
     23 # the Federal Courts of the Northern District of California and the state courts
     24 # of the State of California, with venue lying in Santa Clara County, California.
     25 # 
     26 # Contributor(s):
     27 # 
     28 # If you wish your version of this file to be governed by only the CDDL or only
     29 # the LGPL Version 2.1, indicate your decision by adding "[Contributor]" elects to
     30 # include this software in this distribution under the [CDDL or LGPL Version 2.1]
     31 # license." If you don't indicate a single choice of license, a recipient has the
     32 # option to distribute your version of this file under either the CDDL or the LGPL
     33 # Version 2.1, or to extend the choice of license to its licensees as provided
     34 # above. However, if you add LGPL Version 2.1 code and therefore, elected the LGPL
     35 # Version 2 license, then the option applies only if the new code is made subject
     36 # to such option by the copyright holder. 
     37 
     38 __all__ = ['Trie', 'DATrie', 'match_longest', 'get_ambiguious_length']
     39 
     40 from math import log
     41 import struct
     42 
     43 class Trie (object):
     44     class TrieNode:
     45         def __init__ (self):
     46             self.val = 0
     47             self.trans = {}
     48 
     49     def __init__(self):
     50         self.root = Trie.TrieNode()
     51 
     52     def add(self, word, value=1):
     53         curr_node = self.root
     54         for ch in word:
     55             try: 
     56                 curr_node = curr_node.trans[ch]
     57             except:
     58                 curr_node.trans[ch] = Trie.TrieNode()
     59                 curr_node = curr_node.trans[ch]
     60 
     61         curr_node.val = value
     62 
     63     def walk (self, trienode, ch):
     64         if ch in trienode.trans:
     65             trienode = trienode.trans[ch]
     66             return trienode, trienode.val
     67         else:
     68             return None, 0
     69 
     70 class FlexibleList (list):
     71     def __check_size (self, index):
     72         if index >= len(self):
     73             self.extend ([0] * (index-len(self)+1))
     74 
     75     def __getitem__ (self, index):
     76         self.__check_size (index)
     77         return list.__getitem__(self, index)
     78 
     79     def __setitem__ (self, index, value):
     80         self.__check_size (index)
     81         return list.__setitem__(self, index, value)
     82 
     83 def character_based_encoder (ch, range=('a', 'z')):
     84     ret = ord(ch) - ord(range[0]) + 1
     85     if ret <= 0: ret = ord(range[1]) + 1
     86     return ret
     87 
     88 class DATrie (object):
     89     def __init__(self, chr_encoder=character_based_encoder):
     90         self.root = 0
     91         self.chr_encoder = chr_encoder
     92         self.clear()
     93 
     94     def clear (self):
     95         self.base  = FlexibleList ()
     96         self.check = FlexibleList ()
     97         self.value = FlexibleList ()
     98 
     99     def walk (self, s, ch):
    100         c = self.chr_encoder (ch)
    101         t = abs(self.base[s]) + c
    102 
    103         if t<len(self.check) and self.check[t] == s and self.base[t]:
    104             if self.value: 
    105                 v = self.value[t]
    106             else: 
    107                 v = -1 if self.base[t] < 0 else 0
    108             return t, v
    109         else:
    110             return 0, 0
    111 
    112     def find_base (self, s, children, i=1):
    113         if s == 0 or not children:
    114             return s
    115 
    116         i = max (i, 1)
    117         loop_times = 0
    118         while True:
    119             for ch in children:
    120                 k = i + self.chr_encoder (ch)
    121                 if self.base[k] or self.check[k] or k == s:
    122                     loop_times += 1
    123                     i += int (log (loop_times, 2)) + 1
    124                     break
    125             else:
    126                 break
    127 
    128         return i
    129 
    130     def build (self, words, values=None):
    131         assert (not values or (len(words) == len(values)))
    132         itval = iter(values) if values else None
    133 
    134         trie = Trie()
    135         for w in words:
    136             trie.add (w, itval.next() if itval else -1)
    137 
    138         self.construct_from_trie (trie, values!=None)
    139 
    140     def construct_from_trie (self, trie, with_value=True, progress_cb=None, progress_cb_thr=100):
    141         nodes = [(trie.root, 0)]
    142         find_from = 1
    143         loop_times = 0
    144 
    145         while nodes:
    146             trienode, s = nodes.pop(0)
    147             find_from = b = self.find_base (s, trienode.trans, find_from)
    148             self.base[s] = -b if trienode.val else b
    149             if with_value: self.value[s] = trienode.val
    150 
    151             for ch in trienode.trans:
    152                 c = self.chr_encoder (ch)
    153                 t = abs(self.base[s]) + c
    154                 self.check[t] = s if s else -1
    155 
    156                 nodes.append ((trienode.trans[ch], t))
    157 
    158             loop_times += 1
    159             if loop_times == progress_cb_thr:
    160                 loop_times = 0
    161                 if progress_cb:
    162                     progress_cb ()
    163 
    164         for i in xrange (self.chr_encoder (max(trie.root.trans))+1):
    165             if self.check[i] == -1:
    166                 self.check[i] = 0
    167 
    168     def save (self, fname):
    169         f = open (fname, 'w+')
    170         l = len (self.base)
    171 
    172         using_32bits = int (log (l, 2)) > 15
    173         fmt_str = '%dl'%l if using_32bits else '%dh'%l
    174 
    175         f.write (struct.pack ('L', l))
    176         f.write (struct.pack ('H', using_32bits))
    177         f.write (struct.pack ('H', 1 if self.value else 0))
    178 
    179         f.write (struct.pack (fmt_str, *self.base))
    180         f.write (struct.pack (fmt_str, *self.check))
    181 
    182         if self.value:
    183             if len(self.value) < l: self.value[l-1] = 0
    184             f.write (struct.pack ('%dl'%l, *self.value))
    185 
    186         f.close()
    187 
    188     def load (self, fname):
    189         f = open (fname, 'r')
    190 
    191         l = struct.unpack ('L', f.read(4))[0]
    192         using_32bits = struct.unpack ('H', f.read(2))[0]
    193         has_value = struct.unpack ('H', f.read(2))[0]
    194 
    195         fmt_str = '%dl'%l if using_32bits else '%dh'%l
    196         elm_size = 4 if using_32bits else 2
    197 
    198         self.base  = struct.unpack (fmt_str, f.read(l*elm_size))
    199         self.check = struct.unpack (fmt_str, f.read(l*elm_size))
    200         self.value = struct.unpack ('%dl'%l, f.read(l*4)) if has_value else []
    201 
    202         f.close()
    203 
    204 def match_longest (trie, word):
    205     l = ret_l = ret_v = 0
    206     curr_node = trie.root
    207 
    208     for ch in word:
    209         curr_node, val = trie.walk (curr_node, ch)
    210         if not curr_node: 
    211             break
    212 
    213         l += 1
    214         if val: 
    215             ret_l, ret_v = l, val
    216 
    217     return ret_v, ret_l
    218 
    219 def get_ambiguious_length (trie, str, word_len):
    220     i = 1
    221     while i < word_len and i < len(str):
    222         wid, l = match_longest(trie, str[i:])
    223         if word_len < i + l:
    224             word_len = i + l
    225         i += 1
    226     return i
    227 
    228 def test ():
    229     from pinyin_data import valid_syllables
    230 
    231     trie = Trie()
    232     for s in valid_syllables:
    233         trie.add (s, valid_syllables[s])
    234 
    235     for s in valid_syllables:
    236         v, l = match_longest (trie, s+'b')
    237         assert (len(s) == l and valid_syllables[s] == v)
    238 
    239     datrie = DATrie()
    240     datrie.construct_from_trie (trie)
    241 
    242     datrie.save ('/tmp/trie_test')
    243     datrie.load ('/tmp/trie_test')
    244 
    245     for s in valid_syllables:
    246         v, l = match_longest (datrie, s+'b')
    247         assert (len(s) == l and valid_syllables[s] == v)
    248 
    249     print 'test executed successfully'
    250 
    251 if __name__ == "__main__":
    252     test ()
    253