MS ATC Screen (aka MSS3) decoder
[libav.git] / libavcodec / mss3.c
1 /*
2 * Microsoft Screen 3 (aka Microsoft ATC Screen) decoder
3 * Copyright (c) 2012 Konstantin Shishkov
4 *
5 * This file is part of Libav.
6 *
7 * Libav is free software; you can redistribute it and/or
8 * modify it under the terms of the GNU Lesser General Public
9 * License as published by the Free Software Foundation; either
10 * version 2.1 of the License, or (at your option) any later version.
11 *
12 * Libav is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 * Lesser General Public License for more details.
16 *
17 * You should have received a copy of the GNU Lesser General Public
18 * License along with Libav; if not, write to the Free Software
19 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22 /**
23 * @file
24 * Microsoft Screen 3 (aka Microsoft ATC Screen) decoder
25 */
26
27 #include "avcodec.h"
28 #include "bytestream.h"
29
/*
 * HEADER_SIZE        - bytes of fixed frame header before the range-coded data
 * MODEL2_SCALE       - probability precision (bits) of the binary model
 * MODEL_SCALE        - cumulative-frequency precision (bits) of Model/Model256
 * MODEL256_SEC_SCALE - right shift mapping a cumulative frequency to a slot of
 *                      the 256-symbol model's secondary lookup table
 */
enum {
    HEADER_SIZE        = 27,
    MODEL2_SCALE       = 13,
    MODEL_SCALE        = 15,
    MODEL256_SEC_SCALE = 9,
};
35
/**
 * Adaptive binary model (used for DCT coefficient signs).
 * zero_freq / total_freq are the counts rescaled so that total_freq is
 * about 1 << MODEL2_SCALE.
 */
typedef struct Model2 {
    int upd_val, till_rescale;         ///< rescale period / bits left until the next rescale
    unsigned zero_freq, zero_weight;   ///< scaled / raw count of 0 bits
    unsigned total_freq, total_weight; ///< scaled / raw count of all bits
} Model2;
41
/**
 * Adaptive frequency model over a small alphabet (at most 16 symbols).
 */
typedef struct Model {
    int weights[16], freqs[16]; ///< per-symbol counts / cumulative frequencies scaled to 1 << MODEL_SCALE
    int num_syms;               ///< alphabet size
    int tot_weight;             ///< running total of the weights
    int upd_val, max_upd_val, till_rescale; ///< rescale period, its upper bound, symbols left until the next rescale
} Model;
48
/**
 * Adaptive frequency model over a 256-symbol alphabet, with a secondary
 * table that gives rac_get_model256_sym() initial bounds for its search.
 */
typedef struct Model256 {
    int weights[256], freqs[256]; ///< per-symbol counts / cumulative frequencies scaled to 1 << MODEL_SCALE
    int tot_weight;               ///< running total of the weights
    int secondary[68];            ///< search start symbol, indexed by cumulative frequency >> MODEL256_SEC_SCALE
    int sec_size;                 ///< number of secondary[] entries that are filled
    int upd_val, max_upd_val, till_rescale; ///< rescale period, its upper bound, symbols left until the next rescale
} Model256;
56
/** The range is renormalised (bytes shifted in) whenever it falls below this. */
#define RAC_BOTTOM 0x01000000
/** Range decoder state. */
typedef struct RangeCoder {
    const uint8_t *src, *src_end; ///< next input byte / end of input

    uint32_t range, low;          ///< current interval width / code value relative to the interval start
    int got_error;                ///< set on decoding errors (e.g. input exhausted, malformed DCT block)
} RangeCoder;
64
/** Coding method of one macroblock plane, as returned by decode_block_type(). */
enum BlockType {
    FILL_BLOCK = 0, ///< whole block set to a single (delta-coded) value
    IMAGE_BLOCK,    ///< pixels coded as indices into a small palette, with escapes
    DCT_BLOCK,      ///< 8x8 DCT sub-blocks
    HAAR_BLOCK,     ///< Haar-style coefficient block
    SKIP_BLOCK      ///< nothing decoded; the block is left as it is in the buffer
};
72
/** Per-plane block-type decoder, conditioned on the previously decoded type. */
typedef struct BlockTypeContext {
    int last_type;     ///< type of the previous block in this plane (selects the model)
    Model bt_model[5]; ///< one model per possible previous block type
} BlockTypeContext;
77
/** State for FILL_BLOCK decoding. */
typedef struct FillBlockCoder {
    int fill_val;     ///< current fill value; each block codes a delta to it
    Model coef_model; ///< magnitude-class model for the delta (see decode_coeff())
} FillBlockCoder;
82
/** State for IMAGE_BLOCK decoding. */
typedef struct ImageBlockCoder {
    Model256 esc_model, vec_entry_model; ///< escaped literal pixels / palette entries
    Model vec_size_model;                ///< palette size minus 2
    Model vq_model[125];                 ///< palette-index models, one per (left, top, top-left) index context (5^3)
} ImageBlockCoder;
88
/** State for DCT_BLOCK decoding. */
typedef struct DCTBlockCoder {
    int *prev_dc;       ///< DC values of decoded 8x8 blocks, used for DC prediction
    int prev_dc_stride; ///< row stride of prev_dc, in blocks
    int prev_dc_height; ///< number of rows in prev_dc
    int quality;        ///< quality setting qmat was last generated for
    uint16_t qmat[64];  ///< dequantisation factors, raster (not zigzag) order
    Model dc_model;     ///< DC delta magnitude class
    Model2 sign_model;  ///< AC coefficient signs
    Model256 ac_model;  ///< AC symbols: (zero run << 4) | magnitude class
} DCTBlockCoder;
99
/** State for HAAR_BLOCK decoding. */
typedef struct HaarBlockCoder {
    int quality, scale;  ///< quality setting scale was computed for / coefficient multiplier
    Model256 coef_model; ///< coefficients of the top-left quadrant
    Model coef_hi_model; ///< magnitude class of the remaining coefficients
} HaarBlockCoder;
105
/** Decoder private context. */
typedef struct MSS3Context {
    AVCodecContext *avctx;
    AVFrame pic;                    ///< persistent picture, updated in place every frame

    int got_error;                  ///< a frame failed; inter frames are dropped until the next keyframe
    RangeCoder coder;
    /* per-plane coders: index 0 is luma (16x16 blocks), 1 and 2 chroma (8x8) */
    BlockTypeContext btype[3];
    FillBlockCoder fill_coder[3];
    ImageBlockCoder image_coder[3];
    DCTBlockCoder dct_coder[3];
    HaarBlockCoder haar_coder[3];

    int dctblock[64];               ///< scratch coefficients for one 8x8 DCT block
    int hblock[16 * 16];            ///< scratch coefficients for one Haar block
} MSS3Context;
121
/**
 * Base luma quantisation matrix in raster order; scaled by the frame quality
 * in gen_quant_mat(). (The values appear to match the JPEG Annex K luma table.)
 */
static const uint8_t mss3_luma_quant[64] = {
    16,  11,  10,  16,  24,  40,  51,  61,
    12,  12,  14,  19,  26,  58,  60,  55,
    14,  13,  16,  24,  40,  57,  69,  56,
    14,  17,  22,  29,  51,  87,  80,  62,
    18,  22,  37,  56,  68, 109, 103,  77,
    24,  35,  55,  64,  81, 104, 113,  92,
    49,  64,  78,  87, 103, 121, 120, 101,
    72,  92,  95,  98, 112, 100, 103,  99
};
132
/**
 * Base chroma quantisation matrix in raster order; scaled by the frame
 * quality in gen_quant_mat(). (The values appear to match the JPEG Annex K
 * chroma table.)
 */
static const uint8_t mss3_chroma_quant[64] = {
    17, 18, 24, 47, 99, 99, 99, 99,
    18, 21, 26, 66, 99, 99, 99, 99,
    24, 26, 56, 99, 99, 99, 99, 99,
    47, 66, 99, 99, 99, 99, 99, 99,
    99, 99, 99, 99, 99, 99, 99, 99,
    99, 99, 99, 99, 99, 99, 99, 99,
    99, 99, 99, 99, 99, 99, 99, 99,
    99, 99, 99, 99, 99, 99, 99, 99
};
143
/**
 * Zigzag scan order for 8x8 blocks: zigzag_scan[n] is the raster index of the
 * n-th coefficient in coding order.
 *
 * File-local: the table is only used by decode_dct(), so it must not be
 * exported from the library under this generic name.
 */
static const uint8_t zigzag_scan[64] = {
     0,  1,  8, 16,  9,  2,  3, 10,
    17, 24, 32, 25, 18, 11,  4,  5,
    12, 19, 26, 33, 40, 48, 41, 34,
    27, 20, 13,  6,  7, 14, 21, 28,
    35, 42, 49, 56, 57, 50, 43, 36,
    29, 22, 15, 23, 30, 37, 44, 51,
    58, 59, 52, 45, 38, 31, 39, 46,
    53, 60, 61, 54, 47, 55, 62, 63
};
154
155
156 static void model2_reset(Model2 *m)
157 {
158 m->zero_weight = 1;
159 m->total_weight = 2;
160 m->zero_freq = 0x1000;
161 m->total_freq = 0x2000;
162 m->upd_val = 4;
163 m->till_rescale = 4;
164 }
165
166 static void model2_update(Model2 *m, int bit)
167 {
168 unsigned scale;
169
170 if (!bit)
171 m->zero_weight++;
172 m->till_rescale--;
173 if (m->till_rescale)
174 return;
175
176 m->total_weight += m->upd_val;
177 if (m->total_weight > 0x2000) {
178 m->total_weight = (m->total_weight + 1) >> 1;
179 m->zero_weight = (m->zero_weight + 1) >> 1;
180 if (m->total_weight == m->zero_weight)
181 m->total_weight = m->zero_weight + 1;
182 }
183 m->upd_val = m->upd_val * 5 >> 2;
184 if (m->upd_val > 64)
185 m->upd_val = 64;
186 scale = 0x80000000u / m->total_weight;
187 m->zero_freq = m->zero_weight * scale >> 18;
188 m->total_freq = m->total_weight * scale >> 18;
189 m->till_rescale = m->upd_val;
190 }
191
192 static void model_update(Model *m, int val)
193 {
194 int i, sum = 0;
195 unsigned scale;
196
197 m->weights[val]++;
198 m->till_rescale--;
199 if (m->till_rescale)
200 return;
201 m->tot_weight += m->upd_val;
202
203 if (m->tot_weight > 0x8000) {
204 m->tot_weight = 0;
205 for (i = 0; i < m->num_syms; i++) {
206 m->weights[i] = (m->weights[i] + 1) >> 1;
207 m->tot_weight += m->weights[i];
208 }
209 }
210 scale = 0x80000000u / m->tot_weight;
211 for (i = 0; i < m->num_syms; i++) {
212 m->freqs[i] = sum * scale >> 16;
213 sum += m->weights[i];
214 }
215
216 m->upd_val = m->upd_val * 5 >> 2;
217 if (m->upd_val > m->max_upd_val)
218 m->upd_val = m->max_upd_val;
219 m->till_rescale = m->upd_val;
220 }
221
222 static void model_reset(Model *m)
223 {
224 int i;
225
226 m->tot_weight = 0;
227 for (i = 0; i < m->num_syms - 1; i++)
228 m->weights[i] = 1;
229 m->weights[m->num_syms - 1] = 0;
230
231 m->upd_val = m->num_syms;
232 m->till_rescale = 1;
233 model_update(m, m->num_syms - 1);
234 m->till_rescale =
235 m->upd_val = (m->num_syms + 6) >> 1;
236 }
237
238 static av_cold void model_init(Model *m, int num_syms)
239 {
240 m->num_syms = num_syms;
241 m->max_upd_val = 8 * num_syms + 48;
242
243 model_reset(m);
244 }
245
/**
 * Count one occurrence of symbol val in a 256-symbol model.
 *
 * When the rescale countdown expires: halve the counts if the total grew
 * past 0x8000, rebuild freqs[] (cumulative, scaled to 1 << MODEL_SCALE),
 * rebuild the secondary[] table that rac_get_model256_sym() uses to bound
 * its search, and lengthen the update period by 5/4 up to max_upd_val.
 */
static void model256_update(Model256 *m, int val)
{
    int i, sum = 0;
    unsigned scale;
    int send, sidx = 1;

    m->weights[val]++;
    m->till_rescale--;
    if (m->till_rescale)
        return;
    m->tot_weight += m->upd_val;

    if (m->tot_weight > 0x8000) {
        m->tot_weight = 0;
        for (i = 0; i < 256; i++) {
            m->weights[i] = (m->weights[i] + 1) >> 1;
            m->tot_weight += m->weights[i];
        }
    }
    scale = 0x80000000u / m->tot_weight;
    m->secondary[0] = 0;
    for (i = 0; i < 256; i++) {
        m->freqs[i] = sum * scale >> 16;
        sum += m->weights[i];
        /* every not-yet-filled secondary slot up to freqs[i] >> SEC_SCALE
         * starts its search at the previous symbol */
        send = m->freqs[i] >> MODEL256_SEC_SCALE;
        while (sidx <= send)
            m->secondary[sidx++] = i - 1;
    }
    /* slots past the last cumulative frequency point at the last symbol */
    while (sidx < m->sec_size)
        m->secondary[sidx++] = 255;

    m->upd_val = m->upd_val * 5 >> 2;
    if (m->upd_val > m->max_upd_val)
        m->upd_val = m->max_upd_val;
    m->till_rescale = m->upd_val;
}
282
283 static void model256_reset(Model256 *m)
284 {
285 int i;
286
287 for (i = 0; i < 255; i++)
288 m->weights[i] = 1;
289 m->weights[255] = 0;
290
291 m->tot_weight = 0;
292 m->upd_val = 256;
293 m->till_rescale = 1;
294 model256_update(m, 255);
295 m->till_rescale =
296 m->upd_val = (256 + 6) >> 1;
297 }
298
299 static av_cold void model256_init(Model256 *m)
300 {
301 m->max_upd_val = 8 * 256 + 48;
302 m->sec_size = (1 << 6) + 2;
303
304 model256_reset(m);
305 }
306
307 static void rac_init(RangeCoder *c, const uint8_t *src, int size)
308 {
309 int i;
310
311 c->src = src;
312 c->src_end = src + size;
313 c->low = 0;
314 for (i = 0; i < FFMIN(size, 4); i++)
315 c->low = (c->low << 8) | *c->src++;
316 c->range = 0xFFFFFFFF;
317 c->got_error = 0;
318 }
319
320 static void rac_normalise(RangeCoder *c)
321 {
322 for (;;) {
323 c->range <<= 8;
324 c->low <<= 8;
325 if (c->src < c->src_end) {
326 c->low |= *c->src++;
327 } else if (!c->low) {
328 c->got_error = 1;
329 return;
330 }
331 if (c->range >= RAC_BOTTOM)
332 return;
333 }
334 }
335
336 static int rac_get_bit(RangeCoder *c)
337 {
338 int bit;
339
340 c->range >>= 1;
341
342 bit = (c->range <= c->low);
343 if (bit)
344 c->low -= c->range;
345
346 if (c->range < RAC_BOTTOM)
347 rac_normalise(c);
348
349 return bit;
350 }
351
352 static int rac_get_bits(RangeCoder *c, int nbits)
353 {
354 int val;
355
356 c->range >>= nbits;
357 val = c->low / c->range;
358 c->low -= c->range * val;
359
360 if (c->range < RAC_BOTTOM)
361 rac_normalise(c);
362
363 return val;
364 }
365
366 static int rac_get_model2_sym(RangeCoder *c, Model2 *m)
367 {
368 int bit, helper;
369
370 helper = m->zero_freq * (c->range >> MODEL2_SCALE);
371 bit = (c->low >= helper);
372 if (bit) {
373 c->low -= helper;
374 c->range -= helper;
375 } else {
376 c->range = helper;
377 }
378
379 if (c->range < RAC_BOTTOM)
380 rac_normalise(c);
381
382 model2_update(m, bit);
383
384 return bit;
385 }
386
/**
 * Decode one symbol with the adaptive model m, then update m.
 *
 * Bisects the cumulative frequencies for the largest symbol whose lower bound
 * (freqs[sym] * (range >> MODEL_SCALE)) does not exceed low. The interval is
 * narrowed to that symbol's slice; the last symbol's slice extends to the
 * full original range.
 */
static int rac_get_model_sym(RangeCoder *c, Model *m)
{
    int prob, prob2, helper, val;
    int end, end2;

    prob = 0;              /* lower bound of the chosen symbol */
    prob2 = c->range;      /* upper bound; the full range for the last symbol */
    c->range >>= MODEL_SCALE;
    val = 0;
    end = m->num_syms >> 1;
    end2 = m->num_syms;
    do {
        helper = m->freqs[end] * c->range;
        if (helper <= c->low) {
            val = end;
            prob = helper;
        } else {
            end2 = end;
            prob2 = helper;
        }
        end = (end2 + val) >> 1;
    } while (end != val);
    c->low -= prob;
    c->range = prob2 - prob;
    if (c->range < RAC_BOTTOM)
        rac_normalise(c);

    model_update(m, val);

    return val;
}
418
/**
 * Decode one symbol with the 256-symbol adaptive model m, then update m.
 *
 * The scaled code value helper = low / (range >> MODEL_SCALE) selects a
 * secondary[] slot (helper >> MODEL256_SEC_SCALE). That slot and the next one
 * give the initial search bounds; a bisection then picks the largest symbol
 * with freqs[symbol] <= helper.
 */
static int rac_get_model256_sym(RangeCoder *c, Model256 *m)
{
    int prob, prob2, helper, val;
    int start, end;
    int ssym;

    prob2 = c->range;      /* upper bound used if the last symbol is chosen */
    c->range >>= MODEL_SCALE;

    helper = c->low / c->range;
    ssym = helper >> MODEL256_SEC_SCALE;
    val = m->secondary[ssym];

    /* start remembers the tightest upper bound found so far */
    end = start = m->secondary[ssym + 1] + 1;
    while (end > val + 1) {
        ssym = (end + val) >> 1;
        if (m->freqs[ssym] <= helper) {
            end = start;
            val = ssym;
        } else {
            end = (end + val) >> 1;
            start = ssym;
        }
    }
    prob = m->freqs[val] * c->range;
    if (val != 255)
        prob2 = m->freqs[val + 1] * c->range;

    c->low -= prob;
    c->range = prob2 - prob;
    if (c->range < RAC_BOTTOM)
        rac_normalise(c);

    model256_update(m, val);

    return val;
}
456
457 static int decode_block_type(RangeCoder *c, BlockTypeContext *bt)
458 {
459 bt->last_type = rac_get_model_sym(c, &bt->bt_model[bt->last_type]);
460
461 return bt->last_type;
462 }
463
464 static int decode_coeff(RangeCoder *c, Model *m)
465 {
466 int val, sign;
467
468 val = rac_get_model_sym(c, m);
469 if (val) {
470 sign = rac_get_bit(c);
471 if (val > 1) {
472 val--;
473 val = (1 << val) + rac_get_bits(c, val);
474 }
475 if (!sign)
476 val = -val;
477 }
478
479 return val;
480 }
481
482 static void decode_fill_block(RangeCoder *c, FillBlockCoder *fc,
483 uint8_t *dst, int stride, int block_size)
484 {
485 int i;
486
487 fc->fill_val += decode_coeff(c, &fc->coef_model);
488
489 for (i = 0; i < block_size; i++, dst += stride)
490 memset(dst, fc->fill_val, block_size);
491 }
492
493 static void decode_image_block(RangeCoder *c, ImageBlockCoder *ic,
494 uint8_t *dst, int stride, int block_size)
495 {
496 int i, j;
497 int vec_size;
498 int vec[4];
499 int prev_line[16];
500 int A, B, C;
501
502 vec_size = rac_get_model_sym(c, &ic->vec_size_model) + 2;
503 for (i = 0; i < vec_size; i++)
504 vec[i] = rac_get_model256_sym(c, &ic->vec_entry_model);
505 for (; i < 4; i++)
506 vec[i] = 0;
507 memset(prev_line, 0, sizeof(prev_line));
508
509 for (j = 0; j < block_size; j++) {
510 A = 0;
511 B = 0;
512 for (i = 0; i < block_size; i++) {
513 C = B;
514 B = prev_line[i];
515 A = rac_get_model_sym(c, &ic->vq_model[A + B * 5 + C * 25]);
516
517 prev_line[i] = A;
518 if (A < 4)
519 dst[i] = vec[A];
520 else
521 dst[i] = rac_get_model256_sym(c, &ic->esc_model);
522 }
523 dst += stride;
524 }
525 }
526
527 static int decode_dct(RangeCoder *c, DCTBlockCoder *bc, int *block,
528 int bx, int by)
529 {
530 int skip, val, sign, pos = 1, zz_pos, dc;
531 int blk_pos = bx + by * bc->prev_dc_stride;
532
533 memset(block, 0, sizeof(*block) * 64);
534
535 dc = decode_coeff(c, &bc->dc_model);
536 if (by) {
537 if (bx) {
538 int l, tl, t;
539
540 l = bc->prev_dc[blk_pos - 1];
541 tl = bc->prev_dc[blk_pos - 1 - bc->prev_dc_stride];
542 t = bc->prev_dc[blk_pos - bc->prev_dc_stride];
543
544 if (FFABS(t - tl) <= FFABS(l - tl))
545 dc += l;
546 else
547 dc += t;
548 } else {
549 dc += bc->prev_dc[blk_pos - bc->prev_dc_stride];
550 }
551 } else if (bx) {
552 dc += bc->prev_dc[bx - 1];
553 }
554 bc->prev_dc[blk_pos] = dc;
555 block[0] = dc * bc->qmat[0];
556
557 while (pos < 64) {
558 val = rac_get_model256_sym(c, &bc->ac_model);
559 if (!val)
560 return 0;
561 if (val == 0xF0) {
562 pos += 16;
563 continue;
564 }
565 skip = val >> 4;
566 val = val & 0xF;
567 if (!val)
568 return -1;
569 pos += skip;
570 if (pos >= 64)
571 return -1;
572
573 sign = rac_get_model2_sym(c, &bc->sign_model);
574 if (val > 1) {
575 val--;
576 val = (1 << val) + rac_get_bits(c, val);
577 }
578 if (!sign)
579 val = -val;
580
581 zz_pos = zigzag_scan[pos];
582 block[zz_pos] = val * bc->qmat[zz_pos];
583 pos++;
584 }
585
586 return pos == 64 ? 0 : -1;
587 }
588
/*
 * One-dimensional fixed-point transform of the eight values blk[0 * step],
 * blk[1 * step], ..., blk[7 * step], written back in place.
 *
 * SOP scales the sum and the difference of blk[0] and blk[4 * step] (see
 * SOP_ROW / SOP_COL), and every output is shifted right by shift. dct_put()
 * applies it first to rows (step 1, shift 13) and then to columns (step 8,
 * shift 22) of dequantised DCT coefficients, so it presumably implements an
 * 8-point inverse DCT.
 *
 * Expands to declarations followed by statements, so it must be used at the
 * start of a block scope.
 */
#define DCT_TEMPLATE(blk, step, SOP, shift)                         \
    const int t0 = -39409 * blk[7 * step] - 58980 * blk[1 * step];  \
    const int t1 = 39410 * blk[1 * step] - 58980 * blk[7 * step];   \
    const int t2 = -33410 * blk[5 * step] - 167963 * blk[3 * step]; \
    const int t3 = 33410 * blk[3 * step] - 167963 * blk[5 * step];  \
    const int t4 = blk[3 * step] + blk[7 * step];                   \
    const int t5 = blk[1 * step] + blk[5 * step];                   \
    const int t6 = 77062 * t4 + 51491 * t5;                         \
    const int t7 = 77062 * t5 - 51491 * t4;                         \
    const int t8 = 35470 * blk[2 * step] - 85623 * blk[6 * step];   \
    const int t9 = 35470 * blk[6 * step] + 85623 * blk[2 * step];   \
    const int tA = SOP(blk[0 * step] - blk[4 * step]);              \
    const int tB = SOP(blk[0 * step] + blk[4 * step]);              \
                                                                    \
    blk[0 * step] = (  t1 + t6  + t9 + tB) >> shift;                \
    blk[1 * step] = (  t3 + t7  + t8 + tA) >> shift;                \
    blk[2 * step] = (  t2 + t6  - t8 + tA) >> shift;                \
    blk[3 * step] = (  t0 + t7  - t9 + tB) >> shift;                \
    blk[4 * step] = (-(t0 + t7) - t9 + tB) >> shift;                \
    blk[5 * step] = (-(t2 + t6) - t8 + tA) >> shift;                \
    blk[6 * step] = (-(t3 + t7) + t8 + tA) >> shift;                \
    blk[7 * step] = (-(t1 + t6) + t9 + tB) >> shift;                \

/*
 * Pre-scaling of the even-part inputs of DCT_TEMPLATE. SOP_ROW (row pass)
 * scales by 1 << 16 and adds 0x2000. SOP_COL (column pass) adds 32 before
 * scaling, which equals half of 1 << 22, the column shift.
 *
 * The multiplication is done in unsigned arithmetic because the inputs are
 * often negative: left-shifting a negative int is undefined behaviour
 * (C11 6.5.7), whereas the unsigned product wraps and converts back to the
 * same two's-complement value when stored into the const int temporaries.
 * Both expansions are fully parenthesised so they compose safely inside
 * larger expressions.
 */
#define SOP_ROW(a) (((a) * (1U << 16)) + 0x2000)
#define SOP_COL(a) (((a) + 32) * (1U << 16))
614
615 static void dct_put(uint8_t *dst, int stride, int *block)
616 {
617 int i, j;
618 int *ptr;
619
620 ptr = block;
621 for (i = 0; i < 8; i++) {
622 DCT_TEMPLATE(ptr, 1, SOP_ROW, 13);
623 ptr += 8;
624 }
625
626 ptr = block;
627 for (i = 0; i < 8; i++) {
628 DCT_TEMPLATE(ptr, 8, SOP_COL, 22);
629 ptr++;
630 }
631
632 ptr = block;
633 for (j = 0; j < 8; j++) {
634 for (i = 0; i < 8; i++)
635 dst[i] = av_clip_uint8(ptr[i] + 128);
636 dst += stride;
637 ptr += 8;
638 }
639 }
640
641 static void decode_dct_block(RangeCoder *c, DCTBlockCoder *bc,
642 uint8_t *dst, int stride, int block_size,
643 int *block, int mb_x, int mb_y)
644 {
645 int i, j;
646 int bx, by;
647 int nblocks = block_size >> 3;
648
649 bx = mb_x * nblocks;
650 by = mb_y * nblocks;
651
652 for (j = 0; j < nblocks; j++) {
653 for (i = 0; i < nblocks; i++) {
654 if (decode_dct(c, bc, block, bx + i, by + j)) {
655 c->got_error = 1;
656 return;
657 }
658 dct_put(dst + i * 8, stride, block);
659 }
660 dst += 8 * stride;
661 }
662 }
663
664 static void decode_haar_block(RangeCoder *c, HaarBlockCoder *hc,
665 uint8_t *dst, int stride, int block_size,
666 int *block)
667 {
668 const int hsize = block_size >> 1;
669 int A, B, C, D, t1, t2, t3, t4;
670 int i, j;
671
672 for (j = 0; j < block_size; j++) {
673 for (i = 0; i < block_size; i++) {
674 if (i < hsize && j < hsize)
675 block[i] = rac_get_model256_sym(c, &hc->coef_model);
676 else
677 block[i] = decode_coeff(c, &hc->coef_hi_model);
678 block[i] *= hc->scale;
679 }
680 block += block_size;
681 }
682 block -= block_size * block_size;
683
684 for (j = 0; j < hsize; j++) {
685 for (i = 0; i < hsize; i++) {
686 A = block[i];
687 B = block[i + hsize];
688 C = block[i + hsize * block_size];
689 D = block[i + hsize * block_size + hsize];
690
691 t1 = A - B;
692 t2 = C - D;
693 t3 = A + B;
694 t4 = C + D;
695 dst[i * 2] = av_clip_uint8(t1 - t2);
696 dst[i * 2 + stride] = av_clip_uint8(t1 + t2);
697 dst[i * 2 + 1] = av_clip_uint8(t3 - t4);
698 dst[i * 2 + 1 + stride] = av_clip_uint8(t3 + t4);
699 }
700 block += block_size;
701 dst += stride * 2;
702 }
703 }
704
/**
 * Build a dequantisation matrix: qmat[i] = (ref[i] * scale + 50) / 100,
 * truncated and saturated to 0..65535.
 *
 * @param qmat  output, 64 entries
 * @param ref   base matrix, 64 entries
 * @param scale quality-derived factor; reset_coders() passes 200 - 2 * quality
 *              for quality > 50 and 5000 / quality otherwise (up to 5000)
 *
 * The division happens before narrowing to 16 bits. For low quality settings
 * ref[i] * scale + 50 exceeds 65535 (e.g. 121 * 5000 for quality 1), and
 * converting such a floating-point value directly to uint16_t is undefined
 * behaviour (C11 6.3.1.4). For values the old expression handled correctly
 * the result is unchanged.
 */
static void gen_quant_mat(uint16_t *qmat, const uint8_t *ref, float scale)
{
    int i;

    for (i = 0; i < 64; i++) {
        const double v = ref[i] * scale + 50.0;

        if (v < 0.0)
            qmat[i] = 0;
        else if (v >= 65536.0 * 100)
            qmat[i] = 0xFFFF;
        else
            qmat[i] = (unsigned)v / 100;  /* v < 6553600: fits, quotient <= 65535 */
    }
}
712
713 static void reset_coders(MSS3Context *ctx, int quality)
714 {
715 int i, j;
716
717 for (i = 0; i < 3; i++) {
718 ctx->btype[i].last_type = SKIP_BLOCK;
719 for (j = 0; j < 5; j++)
720 model_reset(&ctx->btype[i].bt_model[j]);
721 ctx->fill_coder[i].fill_val = 0;
722 model_reset(&ctx->fill_coder[i].coef_model);
723 model256_reset(&ctx->image_coder[i].esc_model);
724 model256_reset(&ctx->image_coder[i].vec_entry_model);
725 model_reset(&ctx->image_coder[i].vec_size_model);
726 for (j = 0; j < 125; j++)
727 model_reset(&ctx->image_coder[i].vq_model[j]);
728 if (ctx->dct_coder[i].quality != quality) {
729 float scale;
730 ctx->dct_coder[i].quality = quality;
731 if (quality > 50)
732 scale = 200.0f - 2 * quality;
733 else
734 scale = 5000.0f / quality;
735 gen_quant_mat(ctx->dct_coder[i].qmat,
736 i ? mss3_chroma_quant : mss3_luma_quant,
737 scale);
738 }
739 memset(ctx->dct_coder[i].prev_dc, 0,
740 sizeof(*ctx->dct_coder[i].prev_dc) *
741 ctx->dct_coder[i].prev_dc_stride *
742 ctx->dct_coder[i].prev_dc_height);
743 model_reset(&ctx->dct_coder[i].dc_model);
744 model2_reset(&ctx->dct_coder[i].sign_model);
745 model256_reset(&ctx->dct_coder[i].ac_model);
746 if (ctx->haar_coder[i].quality != quality) {
747 ctx->haar_coder[i].quality = quality;
748 ctx->haar_coder[i].scale = 17 - 7 * quality / 50;
749 }
750 model_reset(&ctx->haar_coder[i].coef_hi_model);
751 model256_reset(&ctx->haar_coder[i].coef_model);
752 }
753 }
754
755 static av_cold void init_coders(MSS3Context *ctx)
756 {
757 int i, j;
758
759 for (i = 0; i < 3; i++) {
760 for (j = 0; j < 5; j++)
761 model_init(&ctx->btype[i].bt_model[j], 5);
762 model_init(&ctx->fill_coder[i].coef_model, 12);
763 model256_init(&ctx->image_coder[i].esc_model);
764 model256_init(&ctx->image_coder[i].vec_entry_model);
765 model_init(&ctx->image_coder[i].vec_size_model, 3);
766 for (j = 0; j < 125; j++)
767 model_init(&ctx->image_coder[i].vq_model[j], 5);
768 model_init(&ctx->dct_coder[i].dc_model, 12);
769 model256_init(&ctx->dct_coder[i].ac_model);
770 model_init(&ctx->haar_coder[i].coef_hi_model, 12);
771 model256_init(&ctx->haar_coder[i].coef_model);
772 }
773 }
774
/**
 * Decode one packet into the persistent picture c->pic.
 *
 * Header (HEADER_SIZE = 27 bytes):
 *  - 32-bit big-endian frame flags; only bits 0x301 may be set, and bit 0
 *    clear marks a keyframe
 *  - 6 skipped bytes
 *  - big-endian 16-bit x, y, width and height of the coded rectangle
 *    (width/height multiples of 16, rectangle inside the picture)
 *  - 4 skipped bytes, a quality byte (1..100), 4 skipped bytes
 * The rest is range-coded data for the rectangle's macroblocks in raster
 * order. Each macroblock holds one 16x16 luma block and two 8x8 chroma
 * blocks, each with its own block type.
 *
 * An inter frame with no payload re-outputs the current picture. After a
 * failed frame, inter frames are dropped (buf_size returned, no output)
 * until the next keyframe.
 *
 * @return buf_size on success, a negative AVERROR code on failure
 */
static int mss3_decode_frame(AVCodecContext *avctx, void *data, int *data_size,
                             AVPacket *avpkt)
{
    const uint8_t *buf = avpkt->data;
    int buf_size = avpkt->size;
    MSS3Context *c = avctx->priv_data;
    RangeCoder *acoder = &c->coder;
    GetByteContext gb;
    uint8_t *dst[3];
    int dec_width, dec_height, dec_x, dec_y, quality, keyframe;
    int x, y, i, mb_width, mb_height, blk_size, btype;
    int ret;

    if (buf_size < HEADER_SIZE) {
        av_log(avctx, AV_LOG_ERROR,
               "Frame should have at least %d bytes, got %d instead\n",
               HEADER_SIZE, buf_size);
        return AVERROR_INVALIDDATA;
    }

    /* --- fixed header --- */
    bytestream2_init(&gb, buf, buf_size);
    keyframe = bytestream2_get_be32(&gb);
    if (keyframe & ~0x301) {
        av_log(avctx, AV_LOG_ERROR, "Invalid frame type %X\n", keyframe);
        return AVERROR_INVALIDDATA;
    }
    keyframe = !(keyframe & 1);
    bytestream2_skip(&gb, 6);
    dec_x = bytestream2_get_be16(&gb);
    dec_y = bytestream2_get_be16(&gb);
    dec_width = bytestream2_get_be16(&gb);
    dec_height = bytestream2_get_be16(&gb);

    if (dec_x + dec_width > avctx->width ||
        dec_y + dec_height > avctx->height ||
        (dec_width | dec_height) & 0xF) {
        av_log(avctx, AV_LOG_ERROR, "Invalid frame dimensions %dx%d +%d,%d\n",
               dec_width, dec_height, dec_x, dec_y);
        return AVERROR_INVALIDDATA;
    }
    bytestream2_skip(&gb, 4);
    quality = bytestream2_get_byte(&gb);
    if (quality < 1 || quality > 100) {
        av_log(avctx, AV_LOG_ERROR, "Invalid quality setting %d\n", quality);
        return AVERROR_INVALIDDATA;
    }
    bytestream2_skip(&gb, 4);

    if (keyframe && !bytestream2_get_bytes_left(&gb)) {
        av_log(avctx, AV_LOG_ERROR, "Keyframe without data found\n");
        return AVERROR_INVALIDDATA;
    }
    /* the reference is broken: drop inter frames until a keyframe */
    if (!keyframe && c->got_error)
        return buf_size;
    c->got_error = 0;

    /* reuse the previous picture so untouched areas keep their content */
    c->pic.reference = 3;
    c->pic.buffer_hints = FF_BUFFER_HINTS_VALID | FF_BUFFER_HINTS_PRESERVE |
                          FF_BUFFER_HINTS_REUSABLE;
    if ((ret = avctx->reget_buffer(avctx, &c->pic)) < 0) {
        av_log(avctx, AV_LOG_ERROR, "reget_buffer() failed\n");
        return ret;
    }
    c->pic.key_frame = keyframe;
    c->pic.pict_type = keyframe ? AV_PICTURE_TYPE_I : AV_PICTURE_TYPE_P;
    if (!bytestream2_get_bytes_left(&gb)) {
        /* empty inter frame: output the picture unchanged */
        *data_size = sizeof(AVFrame);
        *(AVFrame*)data = c->pic;

        return buf_size;
    }

    /* --- range-coded payload --- */
    reset_coders(c, quality);

    rac_init(acoder, buf + HEADER_SIZE, buf_size - HEADER_SIZE);

    mb_width = dec_width >> 4;
    mb_height = dec_height >> 4;
    dst[0] = c->pic.data[0] + dec_x + dec_y * c->pic.linesize[0];
    dst[1] = c->pic.data[1] + dec_x / 2 + (dec_y / 2) * c->pic.linesize[1];
    dst[2] = c->pic.data[2] + dec_x / 2 + (dec_y / 2) * c->pic.linesize[2];
    for (y = 0; y < mb_height; y++) {
        for (x = 0; x < mb_width; x++) {
            for (i = 0; i < 3; i++) {
                blk_size = 8 << !i;   /* 16 for luma, 8 for chroma */

                btype = decode_block_type(acoder, c->btype + i);
                switch (btype) {
                case FILL_BLOCK:
                    decode_fill_block(acoder, c->fill_coder + i,
                                      dst[i] + x * blk_size,
                                      c->pic.linesize[i], blk_size);
                    break;
                case IMAGE_BLOCK:
                    decode_image_block(acoder, c->image_coder + i,
                                       dst[i] + x * blk_size,
                                       c->pic.linesize[i], blk_size);
                    break;
                case DCT_BLOCK:
                    decode_dct_block(acoder, c->dct_coder + i,
                                     dst[i] + x * blk_size,
                                     c->pic.linesize[i], blk_size,
                                     c->dctblock, x, y);
                    break;
                case HAAR_BLOCK:
                    decode_haar_block(acoder, c->haar_coder + i,
                                      dst[i] + x * blk_size,
                                      c->pic.linesize[i], blk_size,
                                      c->hblock);
                    break;
                /* SKIP_BLOCK: leave the block as it is */
                }
                if (c->got_error || acoder->got_error) {
                    av_log(avctx, AV_LOG_ERROR, "Error decoding block %d,%d\n",
                           x, y);
                    c->got_error = 1;
                    return AVERROR_INVALIDDATA;
                }
            }
        }
        dst[0] += c->pic.linesize[0] * 16;
        dst[1] += c->pic.linesize[1] * 8;
        dst[2] += c->pic.linesize[2] * 8;
    }

    *data_size = sizeof(AVFrame);
    *(AVFrame*)data = c->pic;

    return buf_size;
}
904
905 static av_cold int mss3_decode_init(AVCodecContext *avctx)
906 {
907 MSS3Context * const c = avctx->priv_data;
908 int i;
909
910 c->avctx = avctx;
911
912 if ((avctx->width & 0xF) || (avctx->height & 0xF)) {
913 av_log(avctx, AV_LOG_ERROR,
914 "Image dimensions should be a multiple of 16.\n");
915 return AVERROR_INVALIDDATA;
916 }
917
918 c->got_error = 0;
919 for (i = 0; i < 3; i++) {
920 int b_width = avctx->width >> (2 + !!i);
921 int b_height = avctx->height >> (2 + !!i);
922 c->dct_coder[i].prev_dc_stride = b_width;
923 c->dct_coder[i].prev_dc_height = b_height;
924 c->dct_coder[i].prev_dc = av_malloc(sizeof(*c->dct_coder[i].prev_dc) *
925 b_width * b_height);
926 if (!c->dct_coder[i].prev_dc) {
927 av_log(avctx, AV_LOG_ERROR, "Cannot allocate buffer\n");
928 while (i >= 0) {
929 av_freep(&c->dct_coder[i].prev_dc);
930 i--;
931 }
932 return AVERROR(ENOMEM);
933 }
934 }
935
936 avctx->pix_fmt = PIX_FMT_YUV420P;
937 avctx->coded_frame = &c->pic;
938
939 init_coders(c);
940
941 return 0;
942 }
943
944 static av_cold int mss3_decode_end(AVCodecContext *avctx)
945 {
946 MSS3Context * const c = avctx->priv_data;
947 int i;
948
949 if (c->pic.data[0])
950 avctx->release_buffer(avctx, &c->pic);
951 for (i = 0; i < 3; i++)
952 av_freep(&c->dct_coder[i].prev_dc);
953
954 return 0;
955 }
956
/** Decoder registration for MS ATC Screen (Microsoft Screen 3, codec id MSA1). */
AVCodec ff_msa1_decoder = {
    .name           = "msa1",
    .type           = AVMEDIA_TYPE_VIDEO,
    .id             = CODEC_ID_MSA1,
    .priv_data_size = sizeof(MSS3Context),
    .init           = mss3_decode_init,
    .close          = mss3_decode_end,
    .decode         = mss3_decode_frame,
    .capabilities   = CODEC_CAP_DR1,
    .long_name      = NULL_IF_CONFIG_SMALL("MS ATC Screen"),
};