OpenGrok

Cross Reference: slmprune.cpp
xref: /nv-g11n/inputmethod/sunpinyin/slm/src/slm/slmprune/slmprune.cpp
Home | History | Annotate | Line # | Download | only in slmprune
      1 /*
      2  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
      3  *
      4  * Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved.
      5  *
      6  * The contents of this file are subject to the terms of either the GNU Lesser
      7  * General Public License Version 2.1 only ("LGPL") or the Common Development and
      8  * Distribution License ("CDDL")(collectively, the "License"). You may not use this
      9  * file except in compliance with the License. You can obtain a copy of the CDDL at
     10  * http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at
     11  * http://www.opensource.org/licenses/lgpl-license.php. See the License for the
     12  * specific language governing permissions and limitations under the License. When
     13  * distributing the software, include this License Header Notice in each file and
     14  * include the full text of the License in the License file as well as the
     15  * following notice:
     16  *
     17  * NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE
     18  * (CDDL)
     19  * For Covered Software in this distribution, this License shall be governed by the
     20  * laws of the State of California (excluding conflict-of-law provisions).
     21  * Any litigation relating to this License shall be subject to the jurisdiction of
     22  * the Federal Courts of the Northern District of California and the state courts
     23  * of the State of California, with venue lying in Santa Clara County, California.
     24  *
     25  * Contributor(s):
     26  *
     27  * If you wish your version of this file to be governed by only the CDDL or only
     28  * the LGPL Version 2.1, indicate your decision by adding "[Contributor]" elects to
     29  * include this software in this distribution under the [CDDL or LGPL Version 2.1]
     30  * license." If you don't indicate a single choice of license, a recipient has the
     31  * option to distribute your version of this file under either the CDDL or the LGPL
     32  * Version 2.1, or to extend the choice of license to its licensees as provided
     33  * above. However, if you add LGPL Version 2.1 code and therefore, elected the LGPL
     34  * Version 2 license, then the option applies only if the new code is made subject
     35  * to such option by the copyright holder.
     36  */
     37 
     38 #ifdef HAVE_CONFIG_H
     39 #include "config.h"
     40 #endif
     41 
     42 #ifdef HAVE_ASSERT_H
     43 #include <assert.h>
     44 #endif
     45 
     46 #include <stdio.h>
     47 #include <math.h>
     48 
     49 #include "../sim_slm.h"
     50 #include <algorithm>
     51 
/**
 * Sort record for one item of a level that is being pruned.
 *
 * d     - the distance value used to rank the item
 * child - 1 when the item is flagged as having children, 0 otherwise
 * idx   - position of the item inside its level
 *
 * Ordering: every record with child == 0 sorts before every record with
 * child == 1; records with the same flag are ordered by ascending d.
 * idx takes no part in either comparison.
 */
class TNodeInfo {
public:
    double d;
#ifndef WORDS_BIGENDIAN
    unsigned child : 1;
    unsigned idx : 31;
#else
    unsigned idx : 31;
    unsigned child : 1;
#endif

public:
    TNodeInfo(double distance=0.0, int pos=0, bool children=0) : d(distance)
    {
        child = children ? 1 : 0;
        idx = pos;
    }

    /** Strict ordering: flag first (0 before 1), then distance. */
    bool operator< (const TNodeInfo& r) const
    {
        if (child != r.child)
            return child == 0;
        return d < r.d;
    }

    /** Two records are equal when both the flag and the distance match. */
    bool operator==(const TNodeInfo& r) const
    {
        return !(child != r.child || d != r.d);
    }
};
     73 
     74 class CSlmPruner : public CSIMSlm {
     75 public:
     76     CSlmPruner() : CSIMSlm(), cut(NULL)
     77     { }
     78 
     79     ~CSlmPruner()
     80     { if (cut) delete [] cut; }
     81 
     82     void SetCut(int* nCut);
     83     void SetReserve(int* nReserve);
     84     void Prune();
     85     void Write(const char* filename);
     86 
     87 protected:
     88     void PruneLevel(int lvl);
     89     double CalcDistance(int lvl, int* idx, TSIMWordId* hw);
     90     void CalcBOW();
     91 
     92 protected:
     93     int* cut;
     94     int cache_level, cache_idx; // to accelerate the pruning method
     95     double cache_PA, cache_PB;
     96 };
     97 
     98 void CSlmPruner::Prune()
     99 {
    100     printf("Erasing items using Entropy distance"); fflush(stdout);
    101     for (int lvl=N; lvl>0; --lvl)
    102         PruneLevel(lvl);
    103     printf("\n"); fflush(stdout);
    104     CalcBOW();
    105 }
    106 void CSlmPruner::Write(const char* filename)
    107 {
    108     FILE* out = fopen(filename, "wb");
    109     fwrite(&N, sizeof(N), 1, out);
    110     fwrite(&bUseLogPr, sizeof(bUseLogPr), 1, out);
    111     fwrite(sz, sizeof(int), N+1, out);
    112     for (int i=0; i<N; ++i) {
    113         fwrite(level[i], sizeof(TNode), sz[i], out);
    114     }
    115     fwrite(level[N], sizeof(TLeaf), sz[N], out);
    116     fclose(out);
    117 }
    118 
    119 void CSlmPruner::SetReserve(int* nReserve)
    120 {
    121     cut = new int [N+1];
    122     cut[0] = 0;
    123     for (int lvl=1; lvl<=N; ++lvl) {
    124         cut[lvl] = sz[lvl] - 1 - nReserve[lvl];
    125         if (cut[lvl] < 0) cut[lvl] = 0;
    126     }
    127 }
    128 
    129 void CSlmPruner::SetCut(int* nCut)
    130 {
    131     cut = new int [N+1];
    132     cut[0] = 0;
    133     for (int lvl=1; lvl<=N; ++lvl)
    134         cut[lvl] = nCut[lvl];
    135 }
    136 
    137 template <class chIterator>
    138 int CutLevel(CSIMSlm::TNode* pfirst, CSIMSlm::TNode* plast, chIterator chfirst, chIterator chlast, bool bUseLogPr)
    139 {
    140    int idxfirst, idxchk;
    141    chIterator chchk = chfirst;
    142    for (idxfirst=idxchk=0; chchk != chlast; ++chchk, ++idxchk) {
    143         //cut item whoese pr == 1.0; and not psuedo tail
    144         if (chchk->pr != ((bUseLogPr)?0.0:1.0) || (chchk+1) == chlast) {
    145             if (idxfirst < idxchk) *chfirst = *chchk;
    146             while (pfirst != plast && pfirst->child <= idxchk)
    147                 pfirst++->child = idxfirst;
    148             ++idxfirst;
    149             ++chfirst;
    150         }
    151     }
    152     return idxfirst;
    153 }
    154 
    155 void CSlmPruner::PruneLevel(int lvl)
    156 {
    157     cache_level = cache_idx = -1;
    158 
    159     if (cut[lvl] <= 0) {
    160         printf("\n  Level %d (%d items), no need to cut as your command!", lvl, sz[lvl]-1); fflush(stdout);
    161         return;
    162     }
    163 
    164     printf("\n  Level %d (%d items), allocating...", lvl, sz[lvl]-1); fflush(stdout);
    165 
    166     int n = sz[lvl] - 1; //do not count last psuedo tail
    167     if (cut[lvl] >= n) cut[lvl] = n-1;
    168     TNodeInfo* pbuf = new TNodeInfo[n];
    169     TSIMWordId hw[16]; // it should be lvl+1, yet some compiler do not support it
    170     int idx[16];       // it should be lvl+1, yet some compiler do not support it
    171 
    172     printf(", Calculating..."); fflush(stdout);
    173     for (int i=0; i <=lvl; ++i)
    174         idx[i] = 0;
    175     while (idx[lvl] < n) {
    176         if (lvl == N) {
    177             hw[lvl] = (((TLeaf*)level[lvl])+idx[lvl])->id;
    178         } else {
    179             hw[lvl] = (((TNode*)level[lvl])+idx[lvl])->id;
    180         }
    181         for (int j=lvl-1; j >= 0; --j) {
    182             TNode* pnode = ((TNode*)level[j])+idx[j];
    183             for (; (pnode+1)->child <= idx[j+1]; ++pnode, ++idx[j])
    184                 ;
    185             hw[j] = pnode->id;
    186         }
    187         bool has_child = false;
    188         if (lvl != N) {
    189             TNode* pn = ((TNode*)level[lvl]) + idx[lvl];
    190             if ((pn+1)->child > pn->child)
    191                 has_child = true;
    192         }
    193         pbuf[idx[lvl]].child = (has_child)?1:0;
    194         pbuf[idx[lvl]].idx = idx[lvl];
    195         if (!has_child)
    196             pbuf[idx[lvl]].d = CalcDistance(lvl, idx, hw);
    197         ++idx[lvl];
    198     }
    199     printf(", sorting...");
    200     std::make_heap(pbuf, pbuf+n);
    201     std::sort_heap(pbuf, pbuf+n);
    202 
    203     int k = 0;
    204     // because pr in model can not be 1.0, so we use this to mark a item to be prune
    205     for (TNodeInfo* pinfo = pbuf; k < cut[lvl] && pinfo->child == 0; ++k, ++pinfo) {
    206         if (lvl == N) {
    207             if (bUseLogPr)
    208                 (((TLeaf*)level[lvl]) + pinfo->idx)->pr = 0.0; // -log(1.0)
    209             else
    210                 (((TLeaf*)level[lvl]) + pinfo->idx)->pr = 1.0;
    211         } else {
    212             if (bUseLogPr)
    213                 (((TNode*)level[lvl]) + pinfo->idx)->pr = 0.0; // -log(1.0)
    214             else
    215                 (((TNode*)level[lvl]) + pinfo->idx)->pr = 1.0; // -log(1.0)
    216         }
    217     }
    218     printf("(cut %d items), build parent ptr...", k); fflush(stdout);
    219     if (lvl == N) {
    220         k = CutLevel((TNode*)level[lvl-1], ((TNode*)level[lvl-1])+sz[lvl-1], (TLeaf*)level[lvl], ((TLeaf*)level[lvl])+sz[lvl], bUseLogPr);
    221     } else {
    222         k = CutLevel((TNode*)level[lvl-1], ((TNode*)level[lvl-1])+sz[lvl-1], (TNode*)level[lvl], ((TNode*)level[lvl])+sz[lvl], bUseLogPr);
    223     }
    224     sz[lvl] = k; //k is new size
    225     printf("done!");
    226     delete [] pbuf;
    227     cache_level = cache_idx = -1;
    228 }
    229 
    230 template<class chIterator>
    231 double CalcNodeBow(CSlmPruner* pruner, int lvl, TSIMWordId words[], chIterator chh, chIterator cht, bool bUseLogPr)
    232 {
    233     double sumnext = 0.0, sum=0.0;
    234     if (chh == cht)
    235         return 1.0;
    236     for (; chh < cht; ++chh) {
    237         if (bUseLogPr)
    238             sumnext += exp(-double(chh->pr));
    239         else
    240             sumnext += double(chh->pr);
    241         words[lvl+1] = chh->id;
    242         sum += pruner->getPr(lvl, words+2);
    243     }
    244     assert(sumnext >= 0.0 && sumnext < 1.0);
    245     assert(sum >= 0.0 && sum < 1.0);
    246     return (1.0-sumnext)/(1.0-sum);
    247 }
    248 
    249 void CSlmPruner::CalcBOW()
    250 {
    251     printf("\nUpdating back-off weight"); fflush(stdout);
    252     for (int lvl=0; lvl < N; ++lvl) {
    253         printf("\n    Level %d...", lvl); fflush(stdout);
    254         TNode* base[16]; //it should be lvl+1, yet some compiler do not support it
    255         int idx[16];     //it should be lvl+1, yet some compiler do not support it
    256         for (int i=0; i <= lvl; ++i) {
    257             base[i] = (TNode*)level[i];
    258             idx[i] = 0;
    259         }
    260         TSIMWordId words[17];   //it should be lvl+2, yet some compiler do not support it
    261         for (int lsz = sz[lvl]-1; idx[lvl] < lsz; ++idx[lvl]) {
    262             words[lvl] = base[lvl][idx[lvl]].id;
    263             for (int k=lvl-1; k >= 0; --k) {
    264                 while (base[k][idx[k]+1].child <= idx[k+1])
    265                     ++idx[k];
    266                 words[k] = base[k][idx[k]].id;
    267             }
    268             TNode & node = base[lvl][idx[lvl]];
    269             TNode & nodenext = *((&node)+1);
    270 
    271             double bow = 1.0;
    272             if (lvl == N-1) {
    273                 TLeaf* ch = (TLeaf*)level[lvl+1];
    274                 bow = CalcNodeBow(this, lvl, words, &(ch[node.child]), &(ch[nodenext.child]), bUseLogPr);
    275             } else {
    276                 TNode* ch = (TNode*)level[lvl+1];
    277                 bow = CalcNodeBow(this, lvl, words, &(ch[node.child]), &(ch[nodenext.child]), bUseLogPr);
    278             }
    279             if (bUseLogPr)
    280                 node.bow = PR_TYPE(-log(bow));
    281             else
    282                 node.bow = PR_TYPE(bow);
    283         }
    284     }
    285     printf("\n"); fflush(stdout);
    286 }
    287 
    288 double CSlmPruner::CalcDistance(int lvl, int* idx, TSIMWordId* hw)
    289 {
    290     double PA, PB, PHW, PH_W, PH, BOW, _BOW, pr, p_r;
    291     TSIMWordId w = hw[lvl];
    292 
    293     PH=1.0;
    294     TNode* parent = ((TNode*)level[lvl-1])+idx[lvl-1];
    295     if (bUseLogPr)
    296         BOW = exp(-double(parent->bow));  //Fix original bug to use the BOW directly
    297     else
    298         BOW = double(parent->bow);
    299 
    300     for (int i=1; i < lvl; ++i)
    301         PH *= getPr(i, hw+1+(lvl-i));
    302     assert(PH <= 1.0 && PH >0.0);
    303 
    304     if (lvl == N) {
    305         if (bUseLogPr)
    306             PHW = exp(-((((TLeaf*)level[lvl])+idx[lvl])->pr));
    307         else
    308             PHW = ((((TLeaf*)level[lvl])+idx[lvl])->pr);
    309         assert(w == (((TLeaf*)level[lvl])+idx[lvl])->id);
    310     } else {
    311         if (bUseLogPr)
    312             PHW = exp(-((((TNode*)level[lvl])+idx[lvl])->pr));
    313         else
    314             PHW = ((((TNode*)level[lvl])+idx[lvl])->pr);
    315         assert(w == (((TNode*)level[lvl])+idx[lvl])->id);
    316 
    317     }
    318     PH_W = getPr(lvl-1, hw+2);
    319     assert(PHW > 0.0 && PHW < 1.0);
    320     assert(PH_W > 0.0 && PH_W < 1.0);
    321 
    322     if (cache_level != lvl-1 || cache_idx != idx[lvl-1]) {
    323         cache_level = lvl-1;
    324         cache_idx = idx[lvl-1];
    325         cache_PA = cache_PB = 1.0;
    326         for (int h=parent->child, t = (parent+1)->child; h<t; ++h) {
    327             TSIMWordId id;
    328             if (lvl == N) {
    329                 if (bUseLogPr)
    330                     pr = exp(-((((TLeaf*)level[lvl])+h)->pr));
    331                 else
    332                     pr = ((((TLeaf*)level[lvl])+h)->pr);
    333                 id = (((TLeaf*)level[lvl])+h)->id;
    334 
    335             } else {
    336                 if (bUseLogPr)
    337                     pr = exp(-((((TNode*)level[lvl])+h)->pr));
    338                 else
    339                     pr = ((((TNode*)level[lvl])+h)->pr);
    340                 id = (((TNode*)level[lvl])+h)->id;
    341 
    342             }
    343             assert(pr > 0.0 && pr < 1.0);
    344             cache_PA -= pr;
    345 
    346             hw[lvl] = id;
    347             p_r = getPr(lvl-1, hw+2);  // Fix bug from pr = getPr(lvl-1, hw+1)
    348             assert(p_r > 0.0 && p_r < 1.0);
    349             cache_PB -= p_r;
    350         }
    351         assert(cache_PA > -0.01 && cache_PB > -0.01);
    352         if (cache_PA < 0.00001 || cache_PB < 0.00001) {
    353             printf("\n precision problem on %d gram:", lvl-1);
    354             for (int i=1; i < lvl; ++i) printf("%d ", idx[i]);
    355             printf("   ");
    356             if (cache_PA < 0.00001) {
    357                 printf("{1.0 - sigma p(w|h)} ==> 0.00001");
    358                 cache_PA = 0.00001;
    359             }
    360             if (cache_PB < 0.00001) {
    361                 printf("{1.0 - sigma p(w|h')} ==> 0.00001");
    362                 cache_PB = 0.00001;
    363             }
    364         }
    365     }
    366     PA = cache_PA;
    367     PB = cache_PB;
    368 
    369     _BOW = (PA+PHW) / (PB+PH_W); // Fix bug from "(1.0-PA+PHW)/(1.0-PB+PH_W);"
    370 
    371     assert(BOW > 0.0);
    372     assert(_BOW > 0.0);
    373     assert(PA+PHW < 1.01);     // %1 error rate
    374     assert(PB+PH_W < 1.01);    // %1 error rate
    375 
    376     /*
    377      * PH = P(h), PHW = P(w|h), PH_W = P(w|h'), _BOW = bow'(h) (the new bow)
    378      * BOW = bow(h) (the original bow), PA = sum_{w_i:C(w_i,h)=0} P(w_i|h),
    379      * PB = sum_{w_i:C(w_i,h)=0} P(w_i|h')
    380      */
    381     return -(PH * (PHW * (log(PH_W)+log(_BOW)-log(PHW)) + PA * (log(_BOW)-log(BOW)) ));
    382 }
    383 
/**
 * Print the command-line synopsis and description to stdout.
 */
void ShowUsage(void)
{
    fputs("Usage:\n", stdout);
    fputs("    slmprune input_slm result_slm [R|C] num1 num2...\n", stdout);
    fputs("\nDescription:\n", stdout);
    fputs(
        "      This program uses entropy-based method to prune the size of back-off \n"
        "  language model 'input_slm' to a specific size and write to 'result_slm'. \n"
        "  the third parameter [R|C] means the following numbers is the number for\n"
        "  (R)eserve or (C)ut. If (C)ut, the num[k] means how many items in level K\n"
        "  would be cut. If (R)eserve, num[k] means how many item would be reserved\n"
        "  in level k. \n"
        "      Note that we do not ensure that during pruning process,  exactly the\n"
        "  the given number of items are cut or reserved, because some items may \n"
        "  contains high level children, so could not be cut. \n"
        "      Also it's your responsiblity to give right number of arguments based\n"
        "  on 'input_slm'.\n"
        "\nSee Also:\n"
        "    To get information of the back-off language model, try 'slminfo'.\n\n",
        stdout);
}
    404 
    405 int nCut[32];
    406 const char* srcfilename, *tgtfilename;
    407 
    408 int main(int argc, char* argv[])
    409 {
    410     memset(nCut, 0, sizeof(nCut));
    411     if (argc < 5) {
    412         ShowUsage(); exit(100);
    413     }
    414     srcfilename = argv[1];
    415     tgtfilename = argv[2];
    416     bool bCut = (argv[3][0] == 'C' || argv[3][0] == 'c');
    417 
    418     CSlmPruner pruner;
    419     printf("Reading language model %s...", srcfilename); fflush(stdout);
    420     pruner.Load(srcfilename);
    421     printf("done!\n"); fflush(stdout);
    422 
    423     for (int i=4; i < argc && i < 100; ++i)
    424         nCut[i-3] = atoi(argv[i]);
    425 
    426     if (bCut)
    427         pruner.SetCut(nCut);
    428     else
    429         pruner.SetReserve(nCut);
    430     pruner.Prune();
    431 
    432     printf("Writing target language model %s...", tgtfilename); fflush(stdout);
    433     pruner.Write(tgtfilename);
    434     printf("done!\n\n"); fflush(stdout);
    435 
    436     pruner.Free();
    437     return 0;
    438 }
    439