obmc-aware 4mv
[libav.git] / libavcodec / snow.c
index 45483bd..4a96f3e 100644 (file)
@@ -3047,6 +3047,41 @@ static int get_dc(SnowContext *s, int mb_x, int mb_y, int plane_index){
     return clip(((ab<<6) + aa/2)/aa, 0, 255); //FIXME we shouldnt need cliping
 }
 
     return clip(((ab<<6) + aa/2)/aa, 0, 255); //FIXME we shouldnt need cliping
 }
 
+static inline int get_block_bits(SnowContext *s, int x, int y, int w){
+    const int b_stride = s->b_width << s->block_max_depth;
+    const int b_height = s->b_height<< s->block_max_depth;
+    int index= x + y*b_stride;
+    BlockNode *b     = &s->block[index];
+    BlockNode *left  = x ? &s->block[index-1] : &null_block;
+    BlockNode *top   = y ? &s->block[index-b_stride] : &null_block;
+    BlockNode *tl    = y && x ? &s->block[index-b_stride-1] : left;
+    BlockNode *tr    = y && x+w<b_stride ? &s->block[index-b_stride+w] : tl;
+    int dmx, dmy;
+//  int mx_context= av_log2(2*ABS(left->mx - top->mx));
+//  int my_context= av_log2(2*ABS(left->my - top->my));
+
+    if(x<0 || x>=b_stride || y>=b_height)
+        return 0;
+    dmx= b->mx - mid_pred(left->mx, top->mx, tr->mx);
+    dmy= b->my - mid_pred(left->my, top->my, tr->my);
+/*
+1            0      0
+01X          1-2    1
+001XX        3-6    2-3
+0001XXX      7-14   4-7
+00001XXXX   15-30   8-15
+*/
+//FIXME try accurate rate
+//FIXME intra and inter predictors if surrounding blocks arent the same type
+    if(b->type & BLOCK_INTRA){
+        return 3+2*( av_log2(2*ABS(left->color[0] - b->color[0]))
+                   + av_log2(2*ABS(left->color[1] - b->color[1]))
+                   + av_log2(2*ABS(left->color[2] - b->color[2])));
+    }else
+        return 2*(1 + av_log2(2*ABS(dmx))
+                    + av_log2(2*ABS(dmy))); //FIXME kill the 2* can be merged in lambda
+}
+
 static int get_block_rd(SnowContext *s, int mb_x, int mb_y, int plane_index, const uint8_t *obmc_edged){
     Plane *p= &s->plane[plane_index];
     const int block_size = MB_SIZE >> s->block_max_depth;
 static int get_block_rd(SnowContext *s, int mb_x, int mb_y, int plane_index, const uint8_t *obmc_edged){
     Plane *p= &s->plane[plane_index];
     const int block_size = MB_SIZE >> s->block_max_depth;
@@ -3108,41 +3143,75 @@ static int get_block_rd(SnowContext *s, int mb_x, int mb_y, int plane_index, con
  * .RXx.
  * rxx..
  */
  * .RXx.
  * rxx..
  */
-            int x= mb_x + (i&1) - (i>>1);
-            int y= mb_y + (i>>1);
-            int index= x + y*b_stride;
-            BlockNode *b     = &s->block[index];
-            BlockNode *left  = x ? &s->block[index-1] : &null_block;
-            BlockNode *top   = y ? &s->block[index-b_stride] : &null_block;
-            BlockNode *tl    = y && x ? &s->block[index-b_stride-1] : left;
-            BlockNode *tr    = y && x+1<b_stride ? &s->block[index-b_stride+1] : tl;
-            int dmx, dmy;
-//        int mx_context= av_log2(2*ABS(left->mx - top->mx));
-//        int my_context= av_log2(2*ABS(left->my - top->my));
-
-            if(x<0 || x>=b_stride || y>=b_height)
-                continue;
-            dmx= b->mx - mid_pred(left->mx, top->mx, tr->mx);
-            dmy= b->my - mid_pred(left->my, top->my, tr->my);
-/*
-1            0      0
-01X          1-2    1
-001XX        3-6    2-3
-0001XXX      7-14   4-7
-00001XXXX   15-30   8-15
-*/
-//FIXME try accurate rate
-//FIXME intra and inter predictors if surrounding blocks arent the same type
-            if(b->type & BLOCK_INTRA){
-                rate += 3+2*( av_log2(2*ABS(left->color[0] - b->color[0]))
-                            + av_log2(2*ABS(left->color[1] - b->color[1]))
-                            + av_log2(2*ABS(left->color[2] - b->color[2])));
-            }else
-                rate += 2*(1 + av_log2(2*ABS(dmx))
-                             + av_log2(2*ABS(dmy))); //FIXME kill the 2* can be merged in lambda
+            rate += get_block_bits(s, mb_x + (i&1) - (i>>1), mb_y + (i>>1), 1);
+        }
+    }
+    return distortion + rate*penalty_factor;
+}
+
+static int get_4block_rd(SnowContext *s, int mb_x, int mb_y, int plane_index){
+    int i, y2;
+    Plane *p= &s->plane[plane_index];
+    const int block_size = MB_SIZE >> s->block_max_depth;
+    const int block_w    = plane_index ? block_size/2 : block_size;
+    const uint8_t *obmc  = plane_index ? obmc_tab[s->block_max_depth+1] : obmc_tab[s->block_max_depth];
+    const int obmc_stride= plane_index ? block_size : 2*block_size;
+    const int ref_stride= s->current_picture.linesize[plane_index];
+    uint8_t *ref= s->   last_picture.data[plane_index];
+    uint8_t *dst= s->current_picture.data[plane_index];
+    uint8_t *src= s-> input_picture.data[plane_index];
+    const static DWTELEM zero_dst[4096]; //FIXME
+    const int b_stride = s->b_width << s->block_max_depth;
+    const int b_height = s->b_height<< s->block_max_depth;
+    const int w= p->width;
+    const int h= p->height;
+    int distortion= 0;
+    int rate= 0;
+    const int penalty_factor= get_penalty_factor(s->lambda, s->lambda2, s->avctx->me_cmp);
+
+    for(i=0; i<9; i++){
+        int mb_x2= mb_x + (i%3) - 1;
+        int mb_y2= mb_y + (i/3) - 1;
+        int x= block_w*mb_x2 + block_w/2;
+        int y= block_w*mb_y2 + block_w/2;
+
+        add_yblock(s, zero_dst, dst, ref, obmc, 
+                   x, y, block_w, block_w, w, h, /*dst_stride*/0, ref_stride, obmc_stride, mb_x2, mb_y2, 1, 1, plane_index);
+
+        //FIXME find a cleaner/simpler way to skip the outside stuff
+        for(y2= y; y2<0; y2++)
+            memcpy(dst + x + y2*ref_stride, src + x + y2*ref_stride, block_w);
+        for(y2= h; y2<y+block_w; y2++)
+            memcpy(dst + x + y2*ref_stride, src + x + y2*ref_stride, block_w);
+        if(x<0){
+            for(y2= y; y2<y+block_w; y2++)
+                memcpy(dst + x + y2*ref_stride, src + x + y2*ref_stride, -x);
         }
         }
+        if(x+block_w > w){
+            for(y2= y; y2<y+block_w; y2++)
+                memcpy(dst + w + y2*ref_stride, src + w + y2*ref_stride, x+block_w - w);
+        }
+
+        assert(block_w== 8 || block_w==16);
+        distortion += s->dsp.me_cmp[block_w==8](&s->m, src + x + y*ref_stride, dst + x + y*ref_stride, ref_stride, block_w);
     }
 
     }
 
+    if(plane_index==0){
+        BlockNode *b= &s->block[mb_x+mb_y*b_stride];
+        int merged= same_block(b,b+1) && same_block(b,b+b_stride) && same_block(b,b+b_stride+1);
+
+/* ..RRRr
+ * .RXXx.
+ * .RXXx.
+ * rxxx.
+ */
+        if(merged)
+            rate = get_block_bits(s, mb_x, mb_y, 2);
+        for(i=merged?4:0; i<9; i++){
+            static const int dxy[9][2] = {{0,0},{1,0},{0,1},{1,1},{2,0},{2,1},{-1,2},{0,2},{1,2}};
+            rate += get_block_bits(s, mb_x + dxy[i][0], mb_y + dxy[i][1], 1);
+        }
+    }
     return distortion + rate*penalty_factor;
 }
 
     return distortion + rate*penalty_factor;
 }
 
@@ -3190,6 +3259,42 @@ static always_inline int check_block_inter(SnowContext *s, int mb_x, int mb_y, i
     return check_block(s, mb_x, mb_y, p, intra, obmc_edged, best_rd);
 }
 
     return check_block(s, mb_x, mb_y, p, intra, obmc_edged, best_rd);
 }
 
+static always_inline int check_4block_inter(SnowContext *s, int mb_x, int mb_y, int p0, int p1, int *best_rd){
+    const int b_stride= s->b_width << s->block_max_depth;
+    BlockNode *block= &s->block[mb_x + mb_y * b_stride];
+    BlockNode backup[4]= {block[0], block[1], block[b_stride], block[b_stride+1]};
+    int rd, index, value;
+
+    assert(mb_x>=0 && mb_y>=0);
+    assert(mb_x<b_stride);
+    assert(((mb_x|mb_y)&1) == 0);
+
+    index= (p0 + 31*p1) & (ME_CACHE_SIZE-1);
+    value= s->me_cache_generation + (p0>>10) + (p1<<6);
+    if(s->me_cache[index] == value)
+        return 0;
+    s->me_cache[index]= value;
+
+    block->mx= p0;
+    block->my= p1;
+    block->type &= ~BLOCK_INTRA;
+    block[1]= block[b_stride]= block[b_stride+1]= *block;
+
+    rd= get_4block_rd(s, mb_x, mb_y, 0);
+
+//FIXME chroma
+    if(rd < *best_rd){
+        *best_rd= rd;
+        return 1;
+    }else{
+        block[0]= backup[0];
+        block[1]= backup[1];
+        block[b_stride]= backup[2];
+        block[b_stride+1]= backup[3];
+        return 0;
+    }
+}
+
 static void iterative_me(SnowContext *s){
     int pass, mb_x, mb_y;
     const int b_width = s->b_width  << s->block_max_depth;
 static void iterative_me(SnowContext *s){
     int pass, mb_x, mb_y;
     const int b_width = s->b_width  << s->block_max_depth;
@@ -3333,6 +3438,45 @@ static void iterative_me(SnowContext *s){
         if(!change)
             break;
     }
         if(!change)
             break;
     }
+
+    if(s->block_max_depth == 1){
+        int change= 0;
+        for(mb_y= 0; mb_y<b_height; mb_y+=2){
+            for(mb_x= 0; mb_x<b_width; mb_x+=2){
+                int dia_change, i, j;
+                int best_rd, init_rd;
+                const int index= mb_x + mb_y * b_stride;
+                BlockNode *b[4];
+
+                b[0]= &s->block[index];
+                b[1]= b[0]+1;
+                b[2]= b[0]+b_stride;
+                b[3]= b[2]+1;
+                if(same_block(b[0], b[1]) &&
+                   same_block(b[0], b[2]) &&
+                   same_block(b[0], b[3]))
+                    continue;
+
+                if(!s->me_cache_generation)
+                    memset(s->me_cache, 0, sizeof(s->me_cache));
+                s->me_cache_generation += 1<<22;
+
+                init_rd= best_rd= get_4block_rd(s, mb_x, mb_y, 0);
+
+                check_4block_inter(s, mb_x, mb_y,
+                                   (b[0]->mx + b[1]->mx + b[2]->mx + b[3]->mx + 2) >> 2,
+                                   (b[0]->my + b[1]->my + b[2]->my + b[3]->my + 2) >> 2, &best_rd);
+
+                for(i=0; i<4; i++)
+                    if(!(b[i]->type&BLOCK_INTRA))
+                        check_4block_inter(s, mb_x, mb_y, b[i]->mx, b[i]->my, &best_rd);
+
+                if(init_rd != best_rd)
+                    change++;
+            }
+        }
+        av_log(NULL, AV_LOG_ERROR, "pass:4mv changed:%d\n", change*4);
+    }
 }
 
 static void quantize(SnowContext *s, SubBand *b, DWTELEM *src, int stride, int bias){
 }
 
 static void quantize(SnowContext *s, SubBand *b, DWTELEM *src, int stride, int bias){