mpeg12: check scantable indices in all decode_block functions
[libav.git] / libavcodec / mpeg12dec.c
index 49b7d1e..f0e390d 100644 (file)
@@ -119,6 +119,15 @@ static int mpeg_decode_motion(MpegEncContext *s, int fcode, int pred)
     return sign_extend(val, 5 + shift);
 }
 
+#define check_scantable_index(ctx, x)                                     \
+    do {                                                                  \
+        if ((x) > 63) {                                                   \
+            av_log(ctx->avctx, AV_LOG_ERROR, "ac-tex damaged at %d %d\n", \
+                   ctx->mb_x, ctx->mb_y);                                 \
+            return AVERROR_INVALIDDATA;                                   \
+        }                                                                 \
+    } while (0)                                                           \
+
 static inline int mpeg1_decode_block_intra(MpegEncContext *s, int16_t *block, int n)
 {
     int level, dc, diff, i, j, run;
@@ -150,6 +159,7 @@ static inline int mpeg1_decode_block_intra(MpegEncContext *s, int16_t *block, in
                 break;
             } else if (level != 0) {
                 i += run;
+                check_scantable_index(s, i);
                 j = scantable[i];
                 level = (level * qscale * quant_matrix[j]) >> 4;
                 level = (level - 1) | 1;
@@ -166,6 +176,7 @@ static inline int mpeg1_decode_block_intra(MpegEncContext *s, int16_t *block, in
                     level = SHOW_UBITS(re, &s->gb, 8)      ; LAST_SKIP_BITS(re, &s->gb, 8);
                 }
                 i += run;
+                check_scantable_index(s, i);
                 j = scantable[i];
                 if (level < 0) {
                     level = -level;
@@ -177,10 +188,6 @@ static inline int mpeg1_decode_block_intra(MpegEncContext *s, int16_t *block, in
                     level = (level - 1) | 1;
                 }
             }
-            if (i > 63) {
-                av_log(s->avctx, AV_LOG_ERROR, "ac-tex damaged at %d %d\n", s->mb_x, s->mb_y);
-                return -1;
-            }
 
             block[j] = level;
         }
@@ -225,6 +232,7 @@ static inline int mpeg1_decode_block_inter(MpegEncContext *s, int16_t *block, in
 
             if (level != 0) {
                 i += run;
+                check_scantable_index(s, i);
                 j = scantable[i];
                 level = ((level * 2 + 1) * qscale * quant_matrix[j]) >> 5;
                 level = (level - 1) | 1;
@@ -241,6 +249,7 @@ static inline int mpeg1_decode_block_inter(MpegEncContext *s, int16_t *block, in
                     level = SHOW_UBITS(re, &s->gb, 8)      ; SKIP_BITS(re, &s->gb, 8);
                 }
                 i += run;
+                check_scantable_index(s, i);
                 j = scantable[i];
                 if (level < 0) {
                     level = -level;
@@ -252,10 +261,6 @@ static inline int mpeg1_decode_block_inter(MpegEncContext *s, int16_t *block, in
                     level = (level - 1) | 1;
                 }
             }
-            if (i > 63) {
-                av_log(s->avctx, AV_LOG_ERROR, "ac-tex damaged at %d %d\n", s->mb_x, s->mb_y);
-                return -1;
-            }
 
             block[j] = level;
             if (((int32_t)GET_CACHE(re, &s->gb)) <= (int32_t)0xBFFFFFFF)
@@ -300,6 +305,7 @@ static inline int mpeg1_fast_decode_block_inter(MpegEncContext *s, int16_t *bloc
 
             if (level != 0) {
                 i += run;
+                check_scantable_index(s, i);
                 j = scantable[i];
                 level = ((level * 2 + 1) * qscale) >> 1;
                 level = (level - 1) | 1;
@@ -316,6 +322,7 @@ static inline int mpeg1_fast_decode_block_inter(MpegEncContext *s, int16_t *bloc
                     level = SHOW_UBITS(re, &s->gb, 8)      ; SKIP_BITS(re, &s->gb, 8);
                 }
                 i += run;
+                check_scantable_index(s, i);
                 j = scantable[i];
                 if (level < 0) {
                     level = -level;
@@ -381,6 +388,7 @@ static inline int mpeg2_decode_block_non_intra(MpegEncContext *s, int16_t *block
 
             if (level != 0) {
                 i += run;
+                check_scantable_index(s, i);
                 j = scantable[i];
                 level = ((level * 2 + 1) * qscale * quant_matrix[j]) >> 5;
                 level = (level ^ SHOW_SBITS(re, &s->gb, 1)) - SHOW_SBITS(re, &s->gb, 1);
@@ -392,6 +400,7 @@ static inline int mpeg2_decode_block_non_intra(MpegEncContext *s, int16_t *block
                 level = SHOW_SBITS(re, &s->gb, 12); SKIP_BITS(re, &s->gb, 12);
 
                 i += run;
+                check_scantable_index(s, i);
                 j = scantable[i];
                 if (level < 0) {
                     level = ((-level * 2 + 1) * qscale * quant_matrix[j]) >> 5;
@@ -400,10 +409,6 @@ static inline int mpeg2_decode_block_non_intra(MpegEncContext *s, int16_t *block
                     level = ((level * 2 + 1) * qscale * quant_matrix[j]) >> 5;
                 }
             }
-            if (i > 63) {
-                av_log(s->avctx, AV_LOG_ERROR, "ac-tex damaged at %d %d\n", s->mb_x, s->mb_y);
-                return -1;
-            }
 
             mismatch ^= level;
             block[j]  = level;
@@ -450,6 +455,7 @@ static inline int mpeg2_fast_decode_block_non_intra(MpegEncContext *s,
 
         if (level != 0) {
             i += run;
+            check_scantable_index(s, i);
             j  = scantable[i];
             level = ((level * 2 + 1) * qscale) >> 1;
             level = (level ^ SHOW_SBITS(re, &s->gb, 1)) - SHOW_SBITS(re, &s->gb, 1);
@@ -461,6 +467,7 @@ static inline int mpeg2_fast_decode_block_non_intra(MpegEncContext *s,
             level = SHOW_SBITS(re, &s->gb, 12); SKIP_BITS(re, &s->gb, 12);
 
             i += run;
+            check_scantable_index(s, i);
             j  = scantable[i];
             if (level < 0) {
                 level = ((-level * 2 + 1) * qscale) >> 1;
@@ -527,6 +534,7 @@ static inline int mpeg2_decode_block_intra(MpegEncContext *s, int16_t *block, in
                 break;
             } else if (level != 0) {
                 i += run;
+                check_scantable_index(s, i);
                 j  = scantable[i];
                 level = (level * qscale * quant_matrix[j]) >> 4;
                 level = (level ^ SHOW_SBITS(re, &s->gb, 1)) - SHOW_SBITS(re, &s->gb, 1);
@@ -537,6 +545,7 @@ static inline int mpeg2_decode_block_intra(MpegEncContext *s, int16_t *block, in
                 UPDATE_CACHE(re, &s->gb);
                 level = SHOW_SBITS(re, &s->gb, 12); SKIP_BITS(re, &s->gb, 12);
                 i += run;
+                check_scantable_index(s, i);
                 j  = scantable[i];
                 if (level < 0) {
                     level = (-level * qscale * quant_matrix[j]) >> 4;
@@ -545,10 +554,6 @@ static inline int mpeg2_decode_block_intra(MpegEncContext *s, int16_t *block, in
                     level = (level * qscale * quant_matrix[j]) >> 4;
                 }
             }
-            if (i > 63) {
-                av_log(s->avctx, AV_LOG_ERROR, "ac-tex damaged at %d %d\n", s->mb_x, s->mb_y);
-                return -1;
-            }
 
             mismatch ^= level;
             block[j]  = level;
@@ -563,10 +568,10 @@ static inline int mpeg2_decode_block_intra(MpegEncContext *s, int16_t *block, in
 
 static inline int mpeg2_fast_decode_block_intra(MpegEncContext *s, int16_t *block, int n)
 {
-    int level, dc, diff, j, run;
+    int level, dc, diff, i, j, run;
     int component;
     RLTable *rl;
-    uint8_t * scantable = s->intra_scantable.permutated;
+    uint8_t * const scantable = s->intra_scantable.permutated;
     const uint16_t *quant_matrix;
     const int qscale = s->qscale;
 
@@ -585,6 +590,7 @@ static inline int mpeg2_fast_decode_block_intra(MpegEncContext *s, int16_t *bloc
     dc += diff;
     s->last_dc[component] = dc;
     block[0] = dc << (3 - s->intra_dc_precision);
+    i = 0;
     if (s->intra_vlc_format)
         rl = &ff_rl_mpeg2;
     else
@@ -600,8 +606,9 @@ static inline int mpeg2_fast_decode_block_intra(MpegEncContext *s, int16_t *bloc
             if (level == 127) {
                 break;
             } else if (level != 0) {
-                scantable += run;
-                j = *scantable;
+                i += run;
+                check_scantable_index(s, i);
+                j  = scantable[i];
                 level = (level * qscale * quant_matrix[j]) >> 4;
                 level = (level ^ SHOW_SBITS(re, &s->gb, 1)) - SHOW_SBITS(re, &s->gb, 1);
                 LAST_SKIP_BITS(re, &s->gb, 1);
@@ -610,8 +617,9 @@ static inline int mpeg2_fast_decode_block_intra(MpegEncContext *s, int16_t *bloc
                 run = SHOW_UBITS(re, &s->gb, 6) + 1; LAST_SKIP_BITS(re, &s->gb, 6);
                 UPDATE_CACHE(re, &s->gb);
                 level = SHOW_SBITS(re, &s->gb, 12); SKIP_BITS(re, &s->gb, 12);
-                scantable += run;
-                j = *scantable;
+                i += run;
+                check_scantable_index(s, i);
+                j  = scantable[i];
                 if (level < 0) {
                     level = (-level * qscale * quant_matrix[j]) >> 4;
                     level = -level;
@@ -625,7 +633,7 @@ static inline int mpeg2_fast_decode_block_intra(MpegEncContext *s, int16_t *bloc
         CLOSE_READER(re, &s->gb);
     }
 
-    s->block_last_index[n] = scantable - s->intra_scantable.permutated;
+    s->block_last_index[n] = i;
     return 0;
 }