1 /* 2 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER. 3 * 4 * Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved. 5 * 6 * The contents of this file are subject to the terms of either the GNU Lesser 7 * General Public License Version 2.1 only ("LGPL") or the Common Development and 8 * Distribution License ("CDDL")(collectively, the "License"). You may not use this 9 * file except in compliance with the License. You can obtain a copy of the CDDL at 10 * http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at 11 * http://www.opensource.org/licenses/lgpl-license.php. See the License for the 12 * specific language governing permissions and limitations under the License. When 13 * distributing the software, include this License Header Notice in each file and 14 * include the full text of the License in the License file as well as the 15 * following notice: 16 * 17 * NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE 18 * (CDDL) 19 * For Covered Software in this distribution, this License shall be governed by the 20 * laws of the State of California (excluding conflict-of-law provisions). 21 * Any litigation relating to this License shall be subject to the jurisdiction of 22 * the Federal Courts of the Northern District of California and the state courts 23 * of the State of California, with venue lying in Santa Clara County, California. 24 * 25 * Contributor(s): 26 * 27 * If you wish your version of this file to be governed by only the CDDL or only 28 * the LGPL Version 2.1, indicate your decision by adding "[Contributor]" elects to 29 * include this software in this distribution under the [CDDL or LGPL Version 2.1] 30 * license." If you don't indicate a single choice of license, a recipient has the 31 * option to distribute your version of this file under either the CDDL or the LGPL 32 * Version 2.1, or to extend the choice of license to its licensees as provided 33 * above. However, if you add LGPL Version 2.1 code and therefore, elected the LGPL 34 * Version 2 license, then the option applies only if the new code is made subject 35 * to such option by the copyright holder. 36 */ 37 38 #ifdef HAVE_CONFIG_H 39 #include "config.h" 40 #endif 41 42 #ifdef HAVE_ASSERT_H 43 #include <assert.h> 44 #endif 45 46 #include <stdio.h> 47 #include <math.h> 48 49 #include "../sim_slm.h" 50 #include <algorithm> 51 52 class TNodeInfo { 53 public: 54 double d; 55 #ifndef WORDS_BIGENDIAN 56 unsigned child : 1; 57 unsigned idx : 31; 58 #else 59 unsigned idx : 31; 60 unsigned child : 1; 61 #endif 62 63 public: 64 TNodeInfo(double distance=0.0, int pos=0, bool children=0) : d(distance) 65 { idx = pos; child = (children==0)?0:1; } 66 67 bool operator< (const TNodeInfo& r) const 68 { return ((child ^ r.child) == 0)?(d < r.d):(child == 0); } 69 70 bool operator==(const TNodeInfo& r) const 71 { return (child == r.child && d == r.d); } 72 }; 73 74 class CSlmPruner : public CSIMSlm { 75 public: 76 CSlmPruner() : CSIMSlm(), cut(NULL) 77 { } 78 79 ~CSlmPruner() 80 { if (cut) delete [] cut; } 81 82 void SetCut(int* nCut); 83 void SetReserve(int* nReserve); 84 void Prune(); 85 void Write(const char* filename); 86 87 protected: 88 void PruneLevel(int lvl); 89 double CalcDistance(int lvl, int* idx, TSIMWordId* hw); 90 void CalcBOW(); 91 92 protected: 93 int* cut; 94 int cache_level, cache_idx; // to accelerate the pruning method 95 double cache_PA, cache_PB; 96 }; 97 98 void CSlmPruner::Prune() 99 { 100 printf("Erasing items using Entropy distance"); fflush(stdout); 101 for (int lvl=N; lvl>0; --lvl) 102 PruneLevel(lvl); 103 printf("\n"); fflush(stdout); 104 CalcBOW(); 105 } 106 void CSlmPruner::Write(const char* filename) 107 { 108 FILE* out = fopen(filename, "wb"); 109 fwrite(&N, sizeof(N), 1, out); 110 fwrite(&bUseLogPr, sizeof(bUseLogPr), 1, out); 111 fwrite(sz, sizeof(int), N+1, out); 112 for (int i=0; i<N; ++i) { 113 fwrite(level[i], sizeof(TNode), sz[i], out); 114 } 115 fwrite(level[N], sizeof(TLeaf), sz[N], out); 116 fclose(out); 117 } 118 119 void CSlmPruner::SetReserve(int* nReserve) 120 { 121 cut = new int [N+1]; 122 cut[0] = 0; 123 for (int lvl=1; lvl<=N; ++lvl) { 124 cut[lvl] = sz[lvl] - 1 - nReserve[lvl]; 125 if (cut[lvl] < 0) cut[lvl] = 0; 126 } 127 } 128 129 void CSlmPruner::SetCut(int* nCut) 130 { 131 cut = new int [N+1]; 132 cut[0] = 0; 133 for (int lvl=1; lvl<=N; ++lvl) 134 cut[lvl] = nCut[lvl]; 135 } 136 137 template <class chIterator> 138 int CutLevel(CSIMSlm::TNode* pfirst, CSIMSlm::TNode* plast, chIterator chfirst, chIterator chlast, bool bUseLogPr) 139 { 140 int idxfirst, idxchk; 141 chIterator chchk = chfirst; 142 for (idxfirst=idxchk=0; chchk != chlast; ++chchk, ++idxchk) { 143 //cut item whoese pr == 1.0; and not psuedo tail 144 if (chchk->pr != ((bUseLogPr)?0.0:1.0) || (chchk+1) == chlast) { 145 if (idxfirst < idxchk) *chfirst = *chchk; 146 while (pfirst != plast && pfirst->child <= idxchk) 147 pfirst++->child = idxfirst; 148 ++idxfirst; 149 ++chfirst; 150 } 151 } 152 return idxfirst; 153 } 154 155 void CSlmPruner::PruneLevel(int lvl) 156 { 157 cache_level = cache_idx = -1; 158 159 if (cut[lvl] <= 0) { 160 printf("\n Level %d (%d items), no need to cut as your command!", lvl, sz[lvl]-1); fflush(stdout); 161 return; 162 } 163 164 printf("\n Level %d (%d items), allocating...", lvl, sz[lvl]-1); fflush(stdout); 165 166 int n = sz[lvl] - 1; //do not count last psuedo tail 167 if (cut[lvl] >= n) cut[lvl] = n-1; 168 TNodeInfo* pbuf = new TNodeInfo[n]; 169 TSIMWordId hw[16]; // it should be lvl+1, yet some compiler do not support it 170 int idx[16]; // it should be lvl+1, yet some compiler do not support it 171 172 printf(", Calculating..."); fflush(stdout); 173 for (int i=0; i <=lvl; ++i) 174 idx[i] = 0; 175 while (idx[lvl] < n) { 176 if (lvl == N) { 177 hw[lvl] = (((TLeaf*)level[lvl])+idx[lvl])->id; 178 } else { 179 hw[lvl] = (((TNode*)level[lvl])+idx[lvl])->id; 180 } 181 for (int j=lvl-1; j >= 0; --j) { 182 TNode* pnode = ((TNode*)level[j])+idx[j]; 183 for (; (pnode+1)->child <= idx[j+1]; ++pnode, ++idx[j]) 184 ; 185 hw[j] = pnode->id; 186 } 187 bool has_child = false; 188 if (lvl != N) { 189 TNode* pn = ((TNode*)level[lvl]) + idx[lvl]; 190 if ((pn+1)->child > pn->child) 191 has_child = true; 192 } 193 pbuf[idx[lvl]].child = (has_child)?1:0; 194 pbuf[idx[lvl]].idx = idx[lvl]; 195 if (!has_child) 196 pbuf[idx[lvl]].d = CalcDistance(lvl, idx, hw); 197 ++idx[lvl]; 198 } 199 printf(", sorting..."); 200 std::make_heap(pbuf, pbuf+n); 201 std::sort_heap(pbuf, pbuf+n); 202 203 int k = 0; 204 // because pr in model can not be 1.0, so we use this to mark a item to be prune 205 for (TNodeInfo* pinfo = pbuf; k < cut[lvl] && pinfo->child == 0; ++k, ++pinfo) { 206 if (lvl == N) { 207 if (bUseLogPr) 208 (((TLeaf*)level[lvl]) + pinfo->idx)->pr = 0.0; // -log(1.0) 209 else 210 (((TLeaf*)level[lvl]) + pinfo->idx)->pr = 1.0; 211 } else { 212 if (bUseLogPr) 213 (((TNode*)level[lvl]) + pinfo->idx)->pr = 0.0; // -log(1.0) 214 else 215 (((TNode*)level[lvl]) + pinfo->idx)->pr = 1.0; // -log(1.0) 216 } 217 } 218 printf("(cut %d items), build parent ptr...", k); fflush(stdout); 219 if (lvl == N) { 220 k = CutLevel((TNode*)level[lvl-1], ((TNode*)level[lvl-1])+sz[lvl-1], (TLeaf*)level[lvl], ((TLeaf*)level[lvl])+sz[lvl], bUseLogPr); 221 } else { 222 k = CutLevel((TNode*)level[lvl-1], ((TNode*)level[lvl-1])+sz[lvl-1], (TNode*)level[lvl], ((TNode*)level[lvl])+sz[lvl], bUseLogPr); 223 } 224 sz[lvl] = k; //k is new size 225 printf("done!"); 226 delete [] pbuf; 227 cache_level = cache_idx = -1; 228 } 229 230 template<class chIterator> 231 double CalcNodeBow(CSlmPruner* pruner, int lvl, TSIMWordId words[], chIterator chh, chIterator cht, bool bUseLogPr) 232 { 233 double sumnext = 0.0, sum=0.0; 234 if (chh == cht) 235 return 1.0; 236 for (; chh < cht; ++chh) { 237 if (bUseLogPr) 238 sumnext += exp(-double(chh->pr)); 239 else 240 sumnext += double(chh->pr); 241 words[lvl+1] = chh->id; 242 sum += pruner->getPr(lvl, words+2); 243 } 244 assert(sumnext >= 0.0 && sumnext < 1.0); 245 assert(sum >= 0.0 && sum < 1.0); 246 return (1.0-sumnext)/(1.0-sum); 247 } 248 249 void CSlmPruner::CalcBOW() 250 { 251 printf("\nUpdating back-off weight"); fflush(stdout); 252 for (int lvl=0; lvl < N; ++lvl) { 253 printf("\n Level %d...", lvl); fflush(stdout); 254 TNode* base[16]; //it should be lvl+1, yet some compiler do not support it 255 int idx[16]; //it should be lvl+1, yet some compiler do not support it 256 for (int i=0; i <= lvl; ++i) { 257 base[i] = (TNode*)level[i]; 258 idx[i] = 0; 259 } 260 TSIMWordId words[17]; //it should be lvl+2, yet some compiler do not support it 261 for (int lsz = sz[lvl]-1; idx[lvl] < lsz; ++idx[lvl]) { 262 words[lvl] = base[lvl][idx[lvl]].id; 263 for (int k=lvl-1; k >= 0; --k) { 264 while (base[k][idx[k]+1].child <= idx[k+1]) 265 ++idx[k]; 266 words[k] = base[k][idx[k]].id; 267 } 268 TNode & node = base[lvl][idx[lvl]]; 269 TNode & nodenext = *((&node)+1); 270 271 double bow = 1.0; 272 if (lvl == N-1) { 273 TLeaf* ch = (TLeaf*)level[lvl+1]; 274 bow = CalcNodeBow(this, lvl, words, &(ch[node.child]), &(ch[nodenext.child]), bUseLogPr); 275 } else { 276 TNode* ch = (TNode*)level[lvl+1]; 277 bow = CalcNodeBow(this, lvl, words, &(ch[node.child]), &(ch[nodenext.child]), bUseLogPr); 278 } 279 if (bUseLogPr) 280 node.bow = PR_TYPE(-log(bow)); 281 else 282 node.bow = PR_TYPE(bow); 283 } 284 } 285 printf("\n"); fflush(stdout); 286 } 287 288 double CSlmPruner::CalcDistance(int lvl, int* idx, TSIMWordId* hw) 289 { 290 double PA, PB, PHW, PH_W, PH, BOW, _BOW, pr, p_r; 291 TSIMWordId w = hw[lvl]; 292 293 PH=1.0; 294 TNode* parent = ((TNode*)level[lvl-1])+idx[lvl-1]; 295 if (bUseLogPr) 296 BOW = exp(-double(parent->bow)); //Fix original bug to use the BOW directly 297 else 298 BOW = double(parent->bow); 299 300 for (int i=1; i < lvl; ++i) 301 PH *= getPr(i, hw+1+(lvl-i)); 302 assert(PH <= 1.0 && PH >0.0); 303 304 if (lvl == N) { 305 if (bUseLogPr) 306 PHW = exp(-((((TLeaf*)level[lvl])+idx[lvl])->pr)); 307 else 308 PHW = ((((TLeaf*)level[lvl])+idx[lvl])->pr); 309 assert(w == (((TLeaf*)level[lvl])+idx[lvl])->id); 310 } else { 311 if (bUseLogPr) 312 PHW = exp(-((((TNode*)level[lvl])+idx[lvl])->pr)); 313 else 314 PHW = ((((TNode*)level[lvl])+idx[lvl])->pr); 315 assert(w == (((TNode*)level[lvl])+idx[lvl])->id); 316 317 } 318 PH_W = getPr(lvl-1, hw+2); 319 assert(PHW > 0.0 && PHW < 1.0); 320 assert(PH_W > 0.0 && PH_W < 1.0); 321 322 if (cache_level != lvl-1 || cache_idx != idx[lvl-1]) { 323 cache_level = lvl-1; 324 cache_idx = idx[lvl-1]; 325 cache_PA = cache_PB = 1.0; 326 for (int h=parent->child, t = (parent+1)->child; h<t; ++h) { 327 TSIMWordId id; 328 if (lvl == N) { 329 if (bUseLogPr) 330 pr = exp(-((((TLeaf*)level[lvl])+h)->pr)); 331 else 332 pr = ((((TLeaf*)level[lvl])+h)->pr); 333 id = (((TLeaf*)level[lvl])+h)->id; 334 335 } else { 336 if (bUseLogPr) 337 pr = exp(-((((TNode*)level[lvl])+h)->pr)); 338 else 339 pr = ((((TNode*)level[lvl])+h)->pr); 340 id = (((TNode*)level[lvl])+h)->id; 341 342 } 343 assert(pr > 0.0 && pr < 1.0); 344 cache_PA -= pr; 345 346 hw[lvl] = id; 347 p_r = getPr(lvl-1, hw+2); // Fix bug from pr = getPr(lvl-1, hw+1) 348 assert(p_r > 0.0 && p_r < 1.0); 349 cache_PB -= p_r; 350 } 351 assert(cache_PA > -0.01 && cache_PB > -0.01); 352 if (cache_PA < 0.00001 || cache_PB < 0.00001) { 353 printf("\n precision problem on %d gram:", lvl-1); 354 for (int i=1; i < lvl; ++i) printf("%d ", idx[i]); 355 printf(" "); 356 if (cache_PA < 0.00001) { 357 printf("{1.0 - sigma p(w|h)} ==> 0.00001"); 358 cache_PA = 0.00001; 359 } 360 if (cache_PB < 0.00001) { 361 printf("{1.0 - sigma p(w|h')} ==> 0.00001"); 362 cache_PB = 0.00001; 363 } 364 } 365 } 366 PA = cache_PA; 367 PB = cache_PB; 368 369 _BOW = (PA+PHW) / (PB+PH_W); // Fix bug from "(1.0-PA+PHW)/(1.0-PB+PH_W);" 370 371 assert(BOW > 0.0); 372 assert(_BOW > 0.0); 373 assert(PA+PHW < 1.01); // %1 error rate 374 assert(PB+PH_W < 1.01); // %1 error rate 375 376 /* 377 * PH = P(h), PHW = P(w|h), PH_W = P(w|h'), _BOW = bow'(h) (the new bow) 378 * BOW = bow(h) (the original bow), PA = sum_{w_i:C(w_i,h)=0} P(w_i|h), 379 * PB = sum_{w_i:C(w_i,h)=0} P(w_i|h') 380 */ 381 return -(PH * (PHW * (log(PH_W)+log(_BOW)-log(PHW)) + PA * (log(_BOW)-log(BOW)) )); 382 } 383 384 void ShowUsage(void) 385 { 386 printf("Usage:\n"); 387 printf(" slmprune input_slm result_slm [R|C] num1 num2...\n"); 388 printf("\nDescription:\n"); 389 printf("\ 390 This program uses entropy-based method to prune the size of back-off \n\ 391 language model 'input_slm' to a specific size and write to 'result_slm'. \n\ 392 the third parameter [R|C] means the following numbers is the number for\n\ 393 (R)eserve or (C)ut. If (C)ut, the num[k] means how many items in level K\n\ 394 would be cut. If (R)eserve, num[k] means how many item would be reserved\n\ 395 in level k. \n\ 396 Note that we do not ensure that during pruning process, exactly the\n\ 397 the given number of items are cut or reserved, because some items may \n\ 398 contains high level children, so could not be cut. \n\ 399 Also it's your responsiblity to give right number of arguments based\n\ 400 on 'input_slm'.\n\ 401 \nSee Also:\n\ 402 To get information of the back-off language model, try 'slminfo'.\n\n"); 403 } 404 405 int nCut[32]; 406 const char* srcfilename, *tgtfilename; 407 408 int main(int argc, char* argv[]) 409 { 410 memset(nCut, 0, sizeof(nCut)); 411 if (argc < 5) { 412 ShowUsage(); exit(100); 413 } 414 srcfilename = argv[1]; 415 tgtfilename = argv[2]; 416 bool bCut = (argv[3][0] == 'C' || argv[3][0] == 'c'); 417 418 CSlmPruner pruner; 419 printf("Reading language model %s...", srcfilename); fflush(stdout); 420 pruner.Load(srcfilename); 421 printf("done!\n"); fflush(stdout); 422 423 for (int i=4; i < argc && i < 100; ++i) 424 nCut[i-3] = atoi(argv[i]); 425 426 if (bCut) 427 pruner.SetCut(nCut); 428 else 429 pruner.SetReserve(nCut); 430 pruner.Prune(); 431 432 printf("Writing target language model %s...", tgtfilename); fflush(stdout); 433 pruner.Write(tgtfilename); 434 printf("done!\n\n"); fflush(stdout); 435 436 pruner.Free(); 437 return 0; 438 } 439
