Home | History | Annotate | Download | only in slm
      1 /*
      2  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
      3  *
      4  * Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved.
      5  *
      6  * The contents of this file are subject to the terms of either the GNU Lesser
      7  * General Public License Version 2.1 only ("LGPL") or the Common Development and
      8  * Distribution License ("CDDL")(collectively, the "License"). You may not use this
      9  * file except in compliance with the License. You can obtain a copy of the CDDL at
     10  * http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at
     11  * http://www.opensource.org/licenses/lgpl-license.php. See the License for the
     12  * specific language governing permissions and limitations under the License. When
     13  * distributing the software, include this License Header Notice in each file and
     14  * include the full text of the License in the License file as well as the
     15  * following notice:
     16  *
     17  * NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE
     18  * (CDDL)
     19  * For Covered Software in this distribution, this License shall be governed by the
     20  * laws of the State of California (excluding conflict-of-law provisions).
     21  * Any litigation relating to this License shall be subject to the jurisdiction of
     22  * the Federal Courts of the Northern District of California and the state courts
     23  * of the State of California, with venue lying in Santa Clara County, California.
     24  *
     25  * Contributor(s):
     26  *
     27  * If you wish your version of this file to be governed by only the CDDL or only
     28  * the LGPL Version 2.1, indicate your decision by adding "[Contributor]" elects to
     29  * include this software in this distribution under the [CDDL or LGPL Version 2.1]
     30  * license." If you don't indicate a single choice of license, a recipient has the
     31  * option to distribute your version of this file under either the CDDL or the LGPL
     32  * Version 2.1, or to extend the choice of license to its licensees as provided
     33  * above. However, if you add LGPL Version 2.1 code and therefore, elected the LGPL
     34  * Version 2 license, then the option applies only if the new code is made subject
     35  * to such option by the copyright holder.
     36  */
     37 
     38 #ifdef HAVE_CONFIG_H
     39 #include <config.h>
     40 #endif
     41 
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <math.h>

#include <new>
     47 
     48 #include "slm.h"
     49 
     50 #ifdef HAVE_SYS_MMAN_H
     51 #include <sys/mman.h>
     52 #elif defined(BEOS_OS)
     53 #include <be/kernel/OS.h>
     54 #endif
     55 
     56 bool
     57 CThreadSlm::load(const char* fname, bool MMap)
     58 {
     59     int fd = open(fname, O_RDONLY);
     60     m_bufSize = lseek(fd, 0, SEEK_END);
     61     lseek(fd, 0, SEEK_SET);
     62 
     63     m_bMMap = MMap;
     64     if (m_bMMap) {
     65 #ifdef HAVE_SYS_MMAN_H
     66         void* p = mmap(NULL, m_bufSize, PROT_READ, MAP_SHARED, fd, 0);
     67         if (p == MAP_FAILED) {
     68             close(fd);
     69             return false;
     70         }
     71         m_buf = (char *)p;
     72 #elif defined(BEOS_OS)
     73         char *p = NULL;
     74         area_id area = create_area("tmp", (void**)&p, B_ANY_ADDRESS,
     75                                    (m_bufSize + (B_PAGE_SIZE - 1)) & ~(B_PAGE_SIZE - 1),
     76                                    B_NO_LOCK, B_READ_AREA | B_WRITE_AREA);
     77         if (area < 0) {
     78             close(fd);
     79             return false;
     80         }
     81         m_buf = p;
     82 
     83         for (ssize_t len = m_bufSize; len > 0; ) {
     84             ssize_t n = read(fd, p, len);
     85             if (n < 0) break;
     86             p += n;
     87             len -= n;
     88         }
     89 #else // Other OS
     90         #error "No implementation for mmap()"
     91 #endif // HAVE_SYS_MMAN_H
     92     } else {
     93         if ((m_buf = new char[m_bufSize]) == NULL) {
     94             close(fd);
     95             return false;
     96         }
     97         if (read(fd, m_buf, m_bufSize) != m_bufSize) {
     98             delete [] m_buf; m_buf = NULL;
     99             close(fd);
    100             return false;
    101         }
    102     }
    103     close(fd);
    104 
    105     m_N = *(unsigned*)m_buf;
    106     m_UseLogPr = *(((unsigned*)m_buf)+1);
    107     m_LevelSizes = ((unsigned*)m_buf)+2;
    108     m_prTable = (float*)(m_buf + 2*sizeof(unsigned) + (m_N+1)*sizeof(unsigned));
    109     m_bowTable = m_prTable + (1 << BITS_PR);
    110 
    111     TNode* pn = (TNode*)(m_bowTable + (1 << BITS_BOW));
    112 
    113     //Solaris CC would cause error in runtime if using some thing like
    114     //following even using (void**) conversion. So add PtrVoid definition
    115     //m_Levels = new (void*) [m_N + 1];
    116     m_Levels = new PtrVoid[m_N+1];
    117 
    118     for (unsigned lvl = 0; lvl <= m_N; ++lvl) {
    119         m_Levels[lvl] = (void*)pn;
    120         pn += m_LevelSizes[lvl];
    121     }
    122     return true;
    123 }
    124 
    125 void
    126 CThreadSlm::free()
    127 {
    128     delete [] m_Levels;
    129     if (m_buf) {
    130         if (m_bMMap) {
    131 #ifdef HAVE_SYS_MMAN_H
    132             munmap(m_buf, m_bufSize);
    133 #elif defined(BEOS_OS)
    134             delete_area(area_for(m_buf));
    135 #else // Other OS
    136             #error "No implementation for munmap()"
    137 #endif // HAVE_SYS_MMAN_H
    138         } else {
    139             delete [] m_buf;
    140         }
    141     }
    142     m_buf = NULL;
    143     m_Levels = NULL;
    144 }
    145 
    146 template<class NodeT>
    147 unsigned int
    148 find_id(NodeT* base, unsigned int h, unsigned int t, unsigned int id)
    149 {
    150     unsigned int tail = t;
    151     while (h < t) {
    152         int m = (h+t)/2;
    153         NodeT* pm = base+m;
    154         unsigned int thisId = pm->wid();
    155         if (thisId < id)
    156             h = m+1;
    157         else if (thisId > id)
    158             t = m;
    159         else
    160             return m;
    161     }
    162     return tail;
    163 }
    164 
    165 /**
    166 * return value as the model suggested. The history state must be historified
    167 * or the history's level should be 0. when level == 0 but idx != 0, the
    168 * history is a psuedo unigram state used for this model to combine another
    169 * bigram cache language model
    170 */
    171 double
    172 CThreadSlm::rawTransfer(TState history, unsigned int wid, TState& result)
    173 {
    174     unsigned int lvl = history.getLevel();
    175     unsigned int pos = history.getIdx();
    176 
    177     double cost = (m_UseLogPr)?0.0:1.0;
    178 
    179     // NON_Word id must be dealed with special, let it transfer to root
    180     // without any cost
    181     if (ID_NOT_WORD == wid) {
    182         result = 0;
    183         return cost;
    184     }
    185 
    186     while (true) {
    187         //for psuedo cache model unigram state
    188         TNode* pn = ((TNode *)m_Levels[lvl]) + ((lvl)?pos:0);
    189 
    190         unsigned int t = (pn+1)->ch();
    191 
    192         if (lvl < m_N-1) {
    193             TNode* pBase =(TNode*)m_Levels[lvl+1];
    194             unsigned int idx = find_id(pBase, pn->ch(), t, wid);
    195             if (idx != t) {
    196                 result.setIdx(idx);
    197                 result.setLevel(lvl+1);
    198                 double pr = m_prTable[pBase[idx].pr()];
    199                 return (m_UseLogPr)?(cost+pr):(cost*pr);
    200             }
    201 
    202         } else {
    203             TLeaf* pBase =(TLeaf*)m_Levels[lvl+1];
    204             unsigned int idx = find_id(pBase, pn->ch(), t, wid);
    205             if (idx != t) {
    206                 result.setIdx(idx);
    207                 result.setLevel(lvl+1);
    208                 double pr = m_prTable[pBase[idx].pr()];
    209                 return (m_UseLogPr)?(cost+pr):(cost*pr);
    210             }
    211 
    212         }
    213 
    214         if (m_UseLogPr)
    215             cost += m_bowTable[pn->bow()];
    216         else
    217             cost *= m_bowTable[pn->bow()];
    218         if (lvl == 0)
    219             break;
    220         lvl = pn->bol();
    221         pos = pn->bon();
    222     }
    223     result.setLevel(0);
    224     result.setIdx(0);
    225     if (m_UseLogPr)
    226         return cost + m_prTable[((TNode *)m_Levels[0])->pr()];
    227     else
    228         return cost * m_prTable[((TNode *)m_Levels[0])->pr()];
    229 }
    230 
    231 double
    232 CThreadSlm::transferNegLog(TState history, unsigned int wid, TState& result)
    233 {
    234     double cost = rawTransfer(history, wid, result);
    235     if (m_UseLogPr)
    236         return cost;
    237     else
    238         return -log(cost);
    239 }
    240 
    241 double
    242 CThreadSlm::transfer(TState history, unsigned int wid, TState& result)
    243 {
    244     double cost = rawTransfer(history, wid, result);
    245     if (!m_UseLogPr)
    246         return cost;
    247     else
    248         return exp(-cost);
    249 }
    250 
    251 unsigned int
    252 CThreadSlm::lastWordId(TState st)
    253 {
    254     unsigned int lvl = st.getLevel();
    255     if (lvl >= m_N) {
    256         const TLeaf* pn = ((const TLeaf *)m_Levels[m_N]) + st.getIdx();
    257         return pn->wid();
    258     } else if (lvl > 0) {
    259         const TNode *pn = ((const TNode *)m_Levels[st.getLevel()]) + st.getIdx();
    260         return pn->wid();
    261     } else {
    262         unsigned int idx = st.getIdx();
    263         if (idx == 0) {
    264             const TNode *pn = ((const TNode *)m_Levels[st.getLevel()]) + st.getIdx();
    265             return pn->wid();
    266         }
    267         return idx; // return the psuedo state word id
    268     }
    269 }
    270 
    271 CThreadSlm::TState
    272 CThreadSlm::history_state_of(TState st)
    273 {
    274     if (st.getLevel() >= m_N) {
    275         TLeaf* pl = ((TLeaf *)m_Levels[m_N]) + st.getIdx();
    276         return TState(pl->bol(), pl->bon());
    277     } else {
    278         TNode* pn = ((TNode *)m_Levels[st.getLevel()]) + st.getIdx();
    279         if (pn->ch() == (pn+1)->ch())
    280             return TState(pn->bol(), pn->bon());
    281         else
    282             return st;
    283     }
    284 }
    285 
    286 CThreadSlm::TState&
    287 CThreadSlm::historify(TState& st)
    288 {
    289     if (st.getLevel() >= m_N) {
    290         TLeaf* pl = ((TLeaf *)m_Levels[m_N]) + st.getIdx();
    291         st.setLevel(pl->bol());
    292         st.setIdx(pl->bon());
    293     } else {
    294         TNode* pn = ((TNode *)m_Levels[st.getLevel()]) + st.getIdx();
    295         if (pn->ch() == (pn+1)->ch()) {
    296             st.setLevel(pn->bol());
    297             st.setIdx(pn->bon());
    298         }
    299     }
    300     return st;
    301 }
    302