
Neural Machine Translation model predictions are off-by-one

Problem Summary

In the following example, my NMT model has a high loss because its predictions reproduce target_input rather than target_output.

Targetin   :  1  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  9 10 10 10  3  3 10 10  3 10  3  3 10 10  3  9  9  4  4  4  4  4  3 10  3  3  9  9  3  6  6  6  6  6  6 10  9  9 10 10  4  4  4  4  4  4  4  4  4  4  4  4  9  9  9  9  3  3  3  6  6  6  6  6  9  9 10  3  4  4  4  4  4  4  4  4  4  4  4  4  9  9 10  3 10  9  9  3  4  4  4  4  4  4  4  4  4 10 10  4  4  4  4  4  4  4  4  4  4  9  9 10  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  9  3  3 10  6  6  6  6  6  3  9  9  3  3  3  3  3  3  3 10 10  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  5  3  3  3  3 10 10 10  3  9  9  5 10  3  3  3  3  9  9  9  5 10 10 10 10 10  4  4  4  4  3 10  6  6  6  6  6  6  3  5 10 10 10 10  3  9  9  6  6  6  6  6  6  6  6  6  9  9  9  3  3  3  6  6  6  6  6  6  6  6  3  9  9  9  3  3  6  6  6  3  3  3  3  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
Targetout  :  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  9 10 10 10  3  3 10 10  3 10  3  3 10 10  3  9  9  4  4  4  4  4  3 10  3  3  9  9  3  6  6  6  6  6  6 10  9  9 10 10  4  4  4  4  4  4  4  4  4  4  4  4  9  9  9  9  3  3  3  6  6  6  6  6  9  9 10  3  4  4  4  4  4  4  4  4  4  4  4  4  9  9 10  3 10  9  9  3  4  4  4  4  4  4  4  4  4 10 10  4  4  4  4  4  4  4  4  4  4  9  9 10  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  9  3  3 10  6  6  6  6  6  3  9  9  3  3  3  3  3  3  3 10 10  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  5  3  3  3  3 10 10 10  3  9  9  5 10  3  3  3  3  9  9  9  5 10 10 10 10 10  4  4  4  4  3 10  6  6  6  6  6  6  3  5 10 10 10 10  3  9  9  6  6  6  6  6  6  6  6  6  9  9  9  3  3  3  6  6  6  6  6  6  6  6  3  9  9  9  3  3  6  6  6  3  3  3  3  3  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
Prediction :  3  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  3  3  3  3  3  3 10  3  3 10  3  3 10  3  3  9  3  4  4  4  4  4  3 10  3  3  9  3  3  6  6  6  6  6  6 10  9  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3  3  3  3  6  6  6  6  6  9  6  3  3  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3 10  9  3  3  4  4  4  4  4  4  4  4  4  3 10  4  4  4  4  4  4  4  4  4  4  9  3  3  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3 10  6  6  6  6  6  3  9  3  3  3  3  3  3  3  3  3  3  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  3  3  3  3  3 10  3  3  3  9  3  3 10  3  3  3  3  9  3  9  3 10  3  3  3  3  4  4  4  4  3 10  6  6  6  6  6  6  3  3 10  3  3  3  3  9  3  6  6  6  6  6  6  6  6  6  9  6  9  3  3  3  6  6  6  6  6  6  6  6  3  9  3  9  3  3  6  6  6  3  3  3  3  3  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  
6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6
Source     :  9 16  4  7 22 22 19  1 12 19 12 18  5 18  9 18  5  8 12 19 19  5  5 19 22  7 12 12  6 19  7  3 20  7  9 14  4 11 20 12  7  1 18  7  7  5 22  9 13 22 20 19  7 19  7 13  7 11 19 20  6 22 18 17 17  1 12 17 23  7 20  1 13  7 11 11 22  7 12  1 13 12  5  5 19 22  5  5 20  1  5  4 12  9  7 12  8 14 18 22 18 12 18 17 19  4 19 12 11 18  5  9  9  5 14  7 11  6  4 17 23  6  4  5 12  6  7 14  4 20  6  8 12 25  4 19  6  1  5  1  5 20  4 18 12 12  1 11 12  1 25 13 18 19  7 12  7  3  4 22  9  9 12  4  8  9 19  9 22 22 19  1 19  7  5 19  4  5 18 11 13  9  4 14 12 13 20 11 12 11  7  6  1 11 19 20  7 22 22 12 22 22  9  3  8 12 11 14 16  4 11  7 11  1  8  5  5  7 18 16 22 19  9 20  4 12 18  7 19  7  1 12 18 17 12 19  4 20  9  9  1 12  5 18 14 17 17  7  4 13 16 14 12 22 12 22 18  9 12 11  3 18  6 20  7  4 20  7  9  1  7 25 13  5 25 14 11  5 20  7 23 12  5 16 19 19 25 19  7 -1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0

As is evident, the prediction matches target_input almost 100% of the time, when it should match target_output (an off-by-one shift). Loss and gradients are calculated against target_output, so it is strange that the predictions line up with target_input.
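To put a number on the mismatch, here is a minimal sketch that counts position-wise agreement. It uses only the first 18 tokens of each row in the printout above; `count_matches` is a helper name I made up for illustration.

```python
def count_matches(a, b):
    """Number of positions where two equal-length sequences agree."""
    return sum(x == y for x, y in zip(a, b))

# First 18 tokens of each row in the printout above.
target_in  = [1, 3, 3, 3, 3, 6, 6, 6, 9, 7, 7, 7, 4, 4, 4, 4, 4, 9]
target_out = [3, 3, 3, 3, 6, 6, 6, 9, 7, 7, 7, 4, 4, 4, 4, 4, 9, 9]
prediction = [3, 3, 3, 3, 3, 6, 6, 6, 9, 7, 7, 7, 4, 4, 4, 4, 4, 9]

vs_in = count_matches(prediction, target_in)
vs_out = count_matches(prediction, target_out)
print(f"matches target_in:  {vs_in}/{len(prediction)}")
print(f"matches target_out: {vs_out}/{len(prediction)}")
```

On this excerpt the prediction agrees with target_in at 17 of 18 positions (only the start token differs) and with target_out at 13 of 18. The target_out agreement is still fairly high simply because the sequence repeats itself so often.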

Model Overview

An NMT model predicts a sequence of words in a target language from a primary sequence of words in a source language. This is the framework behind Google Translate. Since NMT uses coupled RNNs, it is supervised and requires labelled target input and output.

NMT uses a source sequence, a target_input sequence, and a target_output sequence. In the example below, the encoder RNN (blue) uses the source input words to produce a meaning vector, which it passes to the decoder RNN (red), which uses the meaning vector to produce output.

When making new predictions (inference), the decoder RNN uses its own previous output to seed the prediction at the next timestep. During training, however, it is instead seeded with the correct previous token at each timestep, which improves training. This is why target_input is necessary for training.

[Image: encoder RNN (blue) passing a meaning vector to the decoder RNN (red)]
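The difference between the two decoding modes described above can be sketched in plain Python. This is a toy under stated assumptions, not the real decoder: `step` stands in for one decoder timestep and is stateless (a real LSTM carries hidden state), and `copy_step` is a hypothetical degenerate model that just echoes its input.

```python
from typing import Callable, List


def decode_teacher_forced(step: Callable[[int], int], target_in: List[int]) -> List[int]:
    """Training-style decoding: each step is fed the ground-truth previous token."""
    return [step(tok) for tok in target_in]


def decode_free_running(step: Callable[[int], int], sos: int, eos: int, max_len: int) -> List[int]:
    """Inference-style decoding: each step is fed the model's own previous output."""
    outputs = []
    prev = sos
    for _ in range(max_len):
        prev = step(prev)
        outputs.append(prev)
        if prev == eos:
            break
    return outputs


def copy_step(tok: int) -> int:
    """Hypothetical degenerate 'model' that echoes whatever token it is fed."""
    return tok


SOS, EOS = 1, 2  # assumed values, chosen to match the printout above
target = [3, 3, 6, 6, 9]
target_in = [SOS] + target

teacher_forced = decode_teacher_forced(copy_step, target_in)
free_running = decode_free_running(copy_step, SOS, EOS, max_len=6)
print(teacher_forced)
print(free_running)
```

Under teacher forcing the echo model reproduces target_in exactly, which is the symptom shown in the printout. Run on its own output, the same model never leaves the start token.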

Code to get an iterator with source, target_in, target_out

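import os
import tensorflow as tf  # TF 1.x API
import utils  # assumed to be a local helper module (not shown)

def get_batched_iterator(hparams, src_loc, tgt_loc):
    # Build the integerized CSV files on first run
    if not (os.path.exists('primary.csv') and os.path.exists('secondary.csv')):
        utils.integerize_raw_data()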

    source_dataset = tf.data.TextLineDataset(src_loc)
    target_dataset = tf.data.TextLineDataset(tgt_loc)
    dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
    dataset = dataset.shuffle(hparams.shuffle_buffer_size, seed=hparams.shuffle_seed)

    # Parse each comma-separated line into a 1-D int32 tensor
    dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
                                                  tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
    # target_in = [sos] + target, target_out = target + [eos]
    dataset = dataset.map(lambda source, target: (source, tf.concat(([hparams.sos], target), axis=0), tf.concat((target, [hparams.eos]), axis=0)))
    # Attach the unpadded sequence lengths
    dataset = dataset.map(lambda source, target_in, target_out: (source, target_in, target_out, tf.size(source), tf.size(target_in)))
    # Proceed to batch and return iterator
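The shift between the two decoder sequences is easier to see without TensorFlow. This is a plain-Python sketch of what the two `tf.concat` calls above do; `make_decoder_pair` is an illustrative name, and `sos=1` / `eos=2` are assumed values that appear to match the printout in the problem summary.

```python
def make_decoder_pair(target, sos, eos):
    """Return (target_in, target_out) for one target sequence."""
    target_in = [sos] + list(target)   # decoder input: shifted right by one
    target_out = list(target) + [eos]  # decoder labels: what should be predicted
    return target_in, target_out


target_in, target_out = make_decoder_pair([3, 3, 6, 9], sos=1, eos=2)
print(target_in)
print(target_out)
```

At every step t the label target_out[t] equals the next input target_in[t + 1], so a model that echoes its input is exactly one position behind the labels.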

NMT model core code

import tensorflow as tf  # TF 1.x API
from tensorflow.python.layers import core as layers_core

def __init__(self, hparams, iterator, mode):
        source, target_in, target_out, source_lengths, target_lengths = iterator.get_next()

        # Lookup embeddings
        embedding_encoder = tf.get_variable("embedding_encoder", [hparams.src_vsize, hparams.src_emsize])
        encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, source)
        embedding_decoder = tf.get_variable("embedding_decoder", [hparams.tgt_vsize, hparams.tgt_emsize])
        decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, target_in)

        # Build and run Encoder LSTM
        encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, sequence_length=source_lengths, dtype=tf.float32)

        # Build and run Decoder LSTM with TrainingHelper and output projection layer
        decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
        projection_layer = layers_core.Dense(hparams.tgt_vsize, use_bias=False)
        helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, sequence_length=target_lengths)
        decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state, output_layer=projection_layer)
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
        logits = outputs.rnn_output

        # Compare strings with ==, not `is` (identity is not guaranteed for equal strings)
        if mode == 'TRAIN' or mode == 'EVAL':  # then calculate loss
            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_out, logits=logits)
            target_weights = tf.sequence_mask(target_lengths, maxlen=tf.shape(target_out)[1], dtype=logits.dtype)
            self.loss = tf.reduce_sum((crossent * target_weights) / hparams.batch_size)

        if mode == 'TRAIN':  # then calculate/clip gradients, then optimize model
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(gradients, hparams.max_gradient_norm)

            optimizer = tf.train.AdamOptimizer(hparams.l_rate)
            self.update_step = optimizer.apply_gradients(zip(clipped_gradients, params))

        if mode == 'EVAL':  # then allow access to input/output tensors to printout
            self.src = source
            self.tgt_in = target_in
            self.tgt_out = target_out
            self.logits = logits
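For reference, the loss computed above (per-token cross-entropy against target_out, masked by sequence length, summed and divided by the batch size) can be sketched in plain Python. This is only an illustration of the arithmetic, not the TensorFlow implementation; the function names are made up, and the batch size is taken to be the number of sequences passed in.

```python
import math
from typing import List


def sparse_softmax_xent(logits: List[float], label: int) -> float:
    """Cross-entropy of one softmax distribution against an integer label."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]


def masked_batch_loss(batch_logits, batch_labels, lengths) -> float:
    """Sum of per-token losses over unpadded positions, divided by batch size."""
    total = 0.0
    for seq_logits, seq_labels, length in zip(batch_logits, batch_labels, lengths):
        for t, (logits, label) in enumerate(zip(seq_logits, seq_labels)):
            weight = 1.0 if t < length else 0.0  # same role as tf.sequence_mask
            total += weight * sparse_softmax_xent(logits, label)
    return total / len(batch_logits)


# Two padded sequences of length 3 over a 4-token vocabulary, uniform logits.
uniform = [0.0, 0.0, 0.0, 0.0]
batch_logits = [[uniform] * 3, [uniform] * 3]
batch_labels = [[3, 3, 2], [3, 0, 0]]  # the second sequence has 2 padding steps
loss = masked_batch_loss(batch_logits, batch_labels, lengths=[3, 1])
print(loss)
```

With uniform logits every real token costs log(4), and only 4 of the 6 positions are unmasked, so the result is 4 * log(4) / 2. Note that every unmasked position counts equally here, whether or not the label differs from the previous token.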


Answer

The core issue is that an NMT model trained on a language-like syntax with a highly repetitive structure is incentivized to simply repeat the previous token. Because TrainingHelper feeds it the correct previous token at each step (to speed up training), copying its input is right most of the time, which creates a local minimum that the model is unable to get out of.

The best option I have found is to weight the loss function so that the key points in the output sequence, where the output is not repetitive, are weighted more heavily. This incentivizes the model to get those positions right instead of just repeating the previous token.
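One way this weighting could look, as a minimal sketch rather than the exact scheme I used: give a larger weight to every step where the label differs from the token the decoder was just fed. `transition_weights` and `boost` are hypothetical names, and the value 5.0 is arbitrary and would need tuning.

```python
def transition_weights(target_in, target_out, length, boost=5.0):
    """Per-step loss weights that emphasize non-repetitive positions.

    target_in / target_out are padded sequences of equal length;
    `length` is the number of real (unpadded) steps.
    """
    weights = []
    for t, (prev_tok, next_tok) in enumerate(zip(target_in, target_out)):
        if t >= length:
            weights.append(0.0)    # padding contributes nothing
        elif next_tok != prev_tok:
            weights.append(boost)  # label changes: copying the input fails here
        else:
            weights.append(1.0)    # label repeats the input
    return weights


target_in  = [1, 3, 3, 6, 6, 0, 0]
target_out = [3, 3, 6, 6, 2, 0, 0]
weights = transition_weights(target_in, target_out, length=5)
print(weights)
```

In the model code from the question, a tensor built this way would take the place of target_weights in `crossent * target_weights`, so that mistakes at transitions cost more than mistakes inside a run of repeated tokens.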

User contributions licensed under: CC BY-SA