1. 程式人生 > 其它 >深度學習計算框架綜述(十三)HVX 計算優化實踐—Concat 優化

深度學習計算框架綜述(十三)HVX 計算優化實踐—Concat 優化

技術標籤:深度學習計算框架綜述深度學習

主要有三種維度的Concat,Height、Width、Channel,如果是浮點計算,只需要進行資料拷貝即可,但是基於Quantization Aware Training的定點計算,需要判斷輸入和輸出的量化資訊是否匹配:

如果一致,則可以直接拷貝;如果不一致,則需要轉換量化資訊中的scale和bias值,由於Feature Map的資料範圍是限定的,即0-255,所以可以在init函式中將0-255對應的轉換值提前計算好,這

一步驟的程式碼如下(same_as_output_array_是一個數組,用於記錄每個input和output的量化資訊是否一致),然後在forward

函式中通過Hexagon DSP的查表指令vlut32,加速計算。

for (uint32_t input_id = 0; input_id < concat_callback_t_->common_param->in_num; ++input_id) {
  // Inputs whose quantization parameters already match the output can be
  // copied verbatim in forward(); every other input gets a precomputed
  // 256-entry requantization lookup table, later consumed by the HVX
  // vlut32 path.
  if (concat_quant_info_->bottom_scales[input_id] == concat_quant_info_->activation_scale &&
    concat_quant_info_->bottom_zps[input_id] == concat_quant_info_->activation_zp) {
    same_as_output_array_[input_id] = 1;
  } else {
    same_as_output_array_[input_id] = 0;
    // Build the LUT only when it is actually needed (the original computed
    // the table pointer unconditionally).
    uint8_t *tmp_concat_lookup = concat_callback_t_->concat_lookup + input_id * 256;
    // out_q = clamp(round((in_scale / out_scale) * (in_q - in_zp)) + out_zp, 0, 255)
    const float inv_out_scale = 1.f / concat_quant_info_->activation_scale;
    const float scale = concat_quant_info_->bottom_scales[input_id] * inv_out_scale;
    const float bias = -concat_quant_info_->bottom_zps[input_id] * scale;
    for (int i = 0; i < 256; i++) {
      // roundf keeps the computation in float; round() would promote to
      // double and back for no benefit.
      const int32_t value = static_cast<int32_t>(roundf(i * scale + bias)) + concat_quant_info_->activation_zp;
      tmp_concat_lookup[i] = static_cast<uint8_t>(MAX(MIN(value, 255), 0));
    }
  }
}

對於Height方向的Concat,只需要判斷same_as_output_array_中對應的值是否為1:如果為1,則直接呼叫vmemcpy將input拷貝到output中,否則需要通過vlut指令查表獲取轉換後的值,C/Intrinsic實現程式碼如下:

/*
 * Worker-pool callback: concatenate a contiguous slice of the inputs along
 * the Height axis into the shared output buffer.
 *
 * Each worker claims a unique thread_id via an atomic counter, processes
 * inputs [input_begin_idx, input_end_idx), and either vmemcpy's an input
 * (quantization params already match the output) or requantizes it through
 * a per-input 256-entry LUT using the HVX vlut32 instructions.
 *
 * data: a concat_callback_t* shared by all workers.
 */
static void concat_height_callback(void *data) {
  concat_callback_t *dptr = (concat_callback_t *)data;
  /* Atomic increment returns the 1-based job count; make it a 0-based id. */
  uint32_t thread_id = dspCV_atomic_inc_return((unsigned int *)(&(dptr->job_count))) - 1;
  uint8_t  *concat_lookup = dptr->concat_lookup;
  uint32_t concat_number = dptr->common_param->in_num;
  uint32_t thread_number = dptr->num_workers;
  /* Split the inputs evenly across workers; the last worker may get fewer. */
  uint32_t every_thread_concat = UP_DIV(concat_number, thread_number);
  uint32_t input_begin_idx = thread_id * every_thread_concat;
  uint32_t input_end_idx = MIN(every_thread_concat * (thread_id + 1), concat_number);
  /* The output row offset of this worker's first input is the sum of ALL
   * preceding input heights. A forward scan avoids the signed/unsigned
   * underflow the original backwards loop risked when input_begin_idx == 0. */
  uint32_t out_height_offset = 0;
  uint32_t prefix_end = MIN(input_begin_idx, concat_number);
  for (uint32_t i = 0; i < prefix_end; i++) {
    out_height_offset += dptr->height_array[i];
  }
  /* BUG FIX: do NOT scale the offset by thread_id. out_height_offset already
   * accounts for every preceding input; the old "thread_id *" factor only
   * happened to be correct for thread_id <= 1 (the "use 2 thread" case) and
   * mis-addressed the output for any additional workers. */
  uint8_t *output_data_ptr = dptr->output + dptr->next_out_width_depth * out_height_offset;
  uint32_t padded_in_width_depth = dptr->padded_out_depth * dptr->padded_in_width;
  for (uint32_t input_id = input_begin_idx; input_id < input_end_idx; input_id++) {
    /* Concatenate one input. */
    uint8_t *input_data_ptr = dptr->inputs[input_id];
    uint32_t in_h = dptr->height_array[input_id];
    uint32_t input_d32_byte_size = dptr->next_out_width_depth * in_h;
    /* L2 prefetch descriptor: fetch one row of padded_in_width_depth bytes
     * (width == stride, height == 1, direction == 1). */
    uint64_t L2FETCH_INPUT_REGISTER = (1ULL << 48) | ((uint64_t)(padded_in_width_depth) << 32) |
                                      ((uint64_t)(padded_in_width_depth) << 16) | 1ULL;
    L2FETCH(input_data_ptr, L2FETCH_INPUT_REGISTER);
    if (dptr->same_as_output_array[input_id]) {
      /* Quantization params match the output: plain bulk copy. */
      vmemcpy_asm(output_data_ptr, input_data_ptr, input_d32_byte_size);
      output_data_ptr += input_d32_byte_size;
    } else {
      /* Requantize through this input's precomputed 256-entry LUT.
       * Shuffle the two 128-byte table halves into the byte order vlut32
       * expects. */
      uint8_t *concat_lookup_tmp = concat_lookup + input_id * 256;
      HVX_Vector luta = *(HVX_Vector *) concat_lookup_tmp;
      HVX_Vector lutb = *(HVX_Vector *) (concat_lookup_tmp + 128);
      HVX_Vector lut0 = Q6_Vb_vshuff_Vb(luta);
      HVX_Vector lut1 = Q6_Vb_vshuff_Vb(lutb);
      for (uint32_t h_idx = 0; h_idx < in_h; h_idx++) {
        uint8_t *input_ptr_tmp = input_data_ptr + h_idx * padded_in_width_depth;
        if (h_idx + 1 < in_h) {
          /* Prefetch the next row while the current one is converted. */
          L2FETCH(input_data_ptr + (h_idx + 1) * padded_in_width_depth, L2FETCH_INPUT_REGISTER);
        }
        for (uint32_t i = 0; i < (padded_in_width_depth >> 7); i++) {
          /* Process one 128-byte HVX vector. */
          HVX_Vector vin = *(HVX_Vector *) input_ptr_tmp;
          HVX_Vector *vout = (HVX_Vector *) output_data_ptr;
          /* Each vlut32 pass covers 32 table entries; OR eight passes
           * together to cover all 256. */
          *vout = Q6_Vb_vlut32_VbVbI(vin, lut0, 0);
          *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut0, 1);
          *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut0, 2);
          *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut0, 3);
          *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut1, 4);
          *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut1, 5);
          *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut1, 6);
          *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut1, 7);
          input_ptr_tmp += 128;
          output_data_ptr += 128;
        }
      }
    }
  }

  dspCV_worker_pool_synctoken_jobdone(dptr->token);
}

下圖是vlut32的虛擬碼:

Width和Channel方向的Concat其實也類似,只是需要判斷Width不是4的整數倍以及Channel不是32的整數倍的情況。

Reference Code

Hexagon NN

op_concat.c