深度學習計算框架綜述(十三)HVX 計算優化實踐—Concat 優化
阿新 • • 發佈:2021-01-10
技術標籤:深度學習計算框架綜述深度學習
主要有三種維度的Concat,Height、Width、Channel,如果是浮點計算,只需要進行資料拷貝即可,但是基於Quantization Aware Training的定點計算,需要判斷輸入和輸出的量化資訊是否匹配:
如果一致,則可以直接拷貝;如果不一致,則需要轉換量化資訊中的scale和bias值,由於Feature Map的資料範圍是限定的,即0-255,所以可以在init函式中將0-255對應的轉換值提前計算好,這
一步驟的程式碼如下(same_as_output_array_是一個數組,用於記錄每個input和output的量化資訊是否一致),然後在forward階段直接按該表查表完成轉換即可:
// Init-time step: build a per-input requantization lookup table (runs once).
// For each input, same_as_output_array_[input_id] records whether that input
// already shares the output's quantization parameters (scale and zero-point).
// When it does not, precompute a 256-entry uint8 table that maps every
// possible quantized input value (0-255) to its requantized output value,
// so forward() only needs a table lookup per byte.
for (uint32_t input_id = 0; input_id < concat_callback_t_->common_param->in_num; ++input_id) {
    uint8_t *tmp_concat_lookup = concat_callback_t_->concat_lookup + input_id * 256;
    const bool scale_matches =
        concat_quant_info_->bottom_scales[input_id] == concat_quant_info_->activation_scale;
    const bool zp_matches =
        concat_quant_info_->bottom_zps[input_id] == concat_quant_info_->activation_zp;
    if (scale_matches && zp_matches) {
        // Quantization info identical: forward() can memcpy this input directly.
        same_as_output_array_[input_id] = 1;
        continue;
    }
    same_as_output_array_[input_id] = 0;
    // Requantize: out = round(in * s_in / s_out - zp_in * s_in / s_out) + zp_out,
    // clamped to the uint8 range.
    const float inv_out_scale = 1.f / concat_quant_info_->activation_scale;
    const float rescale = concat_quant_info_->bottom_scales[input_id] * inv_out_scale;
    const float offset = -concat_quant_info_->bottom_zps[input_id] * rescale;
    for (int idx = 0; idx < 256; ++idx) {
        const int32_t requant =
            static_cast<int32_t>(round(idx * rescale + offset)) + concat_quant_info_->activation_zp;
        tmp_concat_lookup[idx] = static_cast<uint8_t>(MAX(MIN(requant, 255), 0));
    }
}
對於Height方向的Concat,只需要判斷same_as_output_array_中對應的值是否為1:如果為1,則直接呼叫vmemcpy將input拷貝到output中,否則需要通過vlut指令查表獲取轉換後的值,C/Intrinsic實現程式碼如下:
static void concat_height_callback(void *data) { concat_callback_t *dptr = (concat_callback_t *)data; uint32_t thread_id = dspCV_atomic_inc_return((unsigned int *)(&(dptr->job_count))) - 1; uint32_t *channel_array = dptr->channel_array; uint8_t *concat_lookup = dptr->concat_lookup; uint32_t concat_number = dptr->common_param->in_num; uint32_t thread_number = dptr->num_workers;//use 2 thread uint32_t every_thread_concat = UP_DIV(concat_number, thread_number);; uint32_t input_begin_idx = thread_id * every_thread_concat; uint32_t input_end_idx = every_thread_concat * (thread_id + 1); input_end_idx = MIN(input_end_idx, concat_number); uint32_t out_height_offset = 0; //compute out_offset of every thread for (int i=input_begin_idx-1; i >= 0 && i < concat_number; i--) { out_height_offset += dptr->height_array[i]; } uint8_t *output_data_ptr = dptr->output + thread_id * dptr->next_out_width_depth * out_height_offset; uint32_t padded_in_width_depth = dptr->padded_out_depth * dptr->padded_in_width; for(; input_begin_idx < input_end_idx; input_begin_idx++) { //concat one input uint32_t input_id = input_begin_idx; uint8_t* input_data_ptr = dptr->inputs[input_id]; uint32_t in_h = dptr->height_array[input_id]; uint32_t input_d32_byte_size = dptr->next_out_width_depth * in_h; uint64_t L2FETCH_INPUT_REGISTER = (1ULL << 48) | ((uint64_t)(padded_in_width_depth) << 32) | ((uint64_t)(padded_in_width_depth) << 16) | 1ULL; L2FETCH(input_data_ptr, L2FETCH_INPUT_REGISTER); if (dptr->same_as_output_array[input_id]) { vmemcpy_asm(output_data_ptr, input_data_ptr, input_d32_byte_size); output_data_ptr += input_d32_byte_size; } else { uint8_t *concat_lookup_tmp = concat_lookup + input_id * 256; // byte shuffle table HVX_Vector luta = *(HVX_Vector *) concat_lookup_tmp; HVX_Vector lutb = *(HVX_Vector *) (concat_lookup_tmp + 128); HVX_Vector lut0 = Q6_Vb_vshuff_Vb(luta); HVX_Vector lut1 = Q6_Vb_vshuff_Vb(lutb); for (uint32_t h_idx=0; h_idx < in_h; h_idx++) { uint8_t *input_ptr_tmp = input_data_ptr 
+ h_idx * padded_in_width_depth; if(h_idx < in_h - 1) { L2FETCH(input_data_ptr + (h_idx+1) * padded_in_width_depth, L2FETCH_INPUT_REGISTER); } for (int i = 0; i< (padded_in_width_depth >> 7); i++) { //deal a hvx data HVX_Vector vin = *(HVX_Vector *) input_ptr_tmp; HVX_Vector *vout = (HVX_Vector *) output_data_ptr; // look up value in table *vout = Q6_Vb_vlut32_VbVbI(vin, lut0, 0); *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut0, 1); *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut0, 2); *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut0, 3); *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut1, 4); *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut1, 5); *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut1, 6); *vout = Q6_Vb_vlut32or_VbVbVbI(*vout, vin, lut1, 7); input_ptr_tmp += 128; output_data_ptr += 128; } } } } dspCV_worker_pool_synctoken_jobdone(dptr->token); }
下圖是vlut32的虛擬碼:
Width和Channel方向的Concat其實也類似,只是需要判斷Width不是4的整數倍以及Channel不是32的整數倍的情況。
Reference Code
Hexagon NN
op_concat.c