backprop_kernel.cl 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. #define THREADS 256
  2. #define WIDTH 16
  3. #define HEIGHT 16
  4. #define ETA 0.3f
  5. #define MOMENTUM 0.3f
  6. #ifndef _BACKPROP_CUDA_KERNEL_H_
  7. #define _BACKPROP_CUDA_KERNEL_H_
  8. #define WM(i, j) weight_matrix[(j) + (i) * WIDTH]
  /*
   * Forward pass helper: per work-group partial dot products between a slice
   * of the input vector and the corresponding rows of the weight matrix.
   *
   * Each work-item (tx, ty) of group `by` handles one weight:
   *   row    = HEIGHT * by + ty + 1   (index into input_cuda)
   *   column = tx + 1
   *   weight = input_hidden_cuda[(hid + 1) * row + column]
   * The product weight * input is placed in the local tile, the tile is
   * summed along ty with a tree reduction, and row 0 of the tile (one sum
   * per column) is written to hidden_partial_sum[by * hid + column - 1].
   *
   * Parameters:
   *   input_cuda         input values, read at index `row`.
   *   output_hidden_cuda not referenced by this kernel.
   *   input_hidden_cuda  weights with row stride (hid + 1); read, then
   *                      overwritten (see note at the write-back below).
   *   hidden_partial_sum output, one row of `hid` partial sums per group.
   *   input_node         local scratch, indexed by ty.
   *   weight_matrix      local scratch tile, indexed by ty * WIDTH + tx.
   *   in                 not referenced by this kernel.
   *   hid                row stride minus one of the weight buffer and row
   *                      length of hidden_partial_sum.
   *
   * NOTE(review): the indexing assumes a WIDTH x HEIGHT work-group and local
   * buffers of at least HEIGHT and WIDTH * HEIGHT floats -- confirm on the
   * host side.
   */
  __kernel void
  bpnn_layerforward_ocl(__global float *input_cuda,
                        __global float *output_hidden_cuda,
                        __global float *input_hidden_cuda,
                        __global float *hidden_partial_sum,
                        __local float *input_node,
                        __local float *weight_matrix,
                        int in,
                        int hid)
  {
      int by = get_group_id(1);
      int tx = get_local_id(0);
      int ty = get_local_id(1);

      int index_in = HEIGHT * by + ty + 1;
      int index = (hid + 1) * index_in + tx + 1;

      /* One work-item per row stages the input value for that row. */
      if (tx == 0) {
          input_node[ty] = input_cuda[index_in];
      }
      barrier(CLK_LOCAL_MEM_FENCE);

      WM(ty, tx) = input_hidden_cuda[index];
      barrier(CLK_LOCAL_MEM_FENCE);

      WM(ty, tx) = WM(ty, tx) * input_node[ty];
      barrier(CLK_LOCAL_MEM_FENCE);

      /*
       * Tree reduction along ty with strides 2, 4, ..., HEIGHT.
       * The stride must start at 2: with a stride of 1 every work-item would
       * pass the modulo test and add its own element to itself
       * (power_two / 2 == 0), doubling every resulting sum.
       */
      for (int power_two = 2; power_two <= HEIGHT; power_two *= 2) {
          if (ty % power_two == 0) {
              WM(ty, tx) = WM(ty, tx) + WM(ty + power_two / 2, tx);
          }
          barrier(CLK_LOCAL_MEM_FENCE);
      }

      /*
       * NOTE(review): this stores the tile contents (products / partial sums)
       * back over the weights in global memory. Kept as in the original
       * kernel; confirm the host re-uploads the weights before reusing them.
       */
      input_hidden_cuda[index] = WM(ty, tx);
      barrier(CLK_LOCAL_MEM_FENCE);

      /* Row 0 of the tile now holds the per-column sums for this group. */
      if (tx == 0) {
          hidden_partial_sum[by * hid + ty] = WM(0, ty);
      }
  }
  /*
   * Momentum weight update.
   *
   * For work-item (tx, ty) of group `by`, with
   *   index_y = HEIGHT * by + ty + 1
   *   index_x = tx + 1
   *   index   = (hid + 1) * index_y + index_x
   * the kernel computes
   *   step        = ETA * delta[index_x] * ly[index_y] + MOMENTUM * oldw[index]
   *   w[index]   += step
   *   oldw[index] = step
   * Work-items with ty == 0 in group 0 additionally update entry index_x of
   * w / oldw using ETA * delta[index_x] without the ly factor
   * (presumably the bias row -- confirm against the host-side layout).
   *
   * Parameters:
   *   delta  read at index_x.
   *   hid    row stride minus one of w / oldw.
   *   ly     read at index_y.
   *   in     not referenced by this kernel.
   *   w      weights, updated in place.
   *   oldw   previous weight changes, replaced by the new changes.
   */
  __kernel void bpnn_adjust_weights_ocl(__global float *delta,
                                        int hid,
                                        __global float *ly,
                                        int in,
                                        __global float *w,
                                        __global float *oldw)
  {
      const int group_row = get_group_id(1);
      const int col = get_local_id(0);
      const int row = get_local_id(1);

      const int index_x = col + 1;
      const int index_y = HEIGHT * group_row + row + 1;
      const int index = (hid + 1) * index_y + index_x;

      /* Gradient term shared by the weight and the stored weight change. */
      const float grad = ETA * delta[index_x] * ly[index_y];

      /* oldw[index] is read for the weight update before it is replaced. */
      w[index] += grad + MOMENTUM * oldw[index];
      oldw[index] = grad + MOMENTUM * oldw[index];

      barrier(CLK_LOCAL_MEM_FENCE);

      /* Only the first row of the first group handles entries 1..WIDTH. */
      if (row != 0 || group_row != 0) {
          return;
      }

      const float grad0 = ETA * delta[index_x];

      w[index_x] += grad0 + MOMENTUM * oldw[index_x];
      oldw[index_x] = grad0 + MOMENTUM * oldw[index_x];
  }
  66. #endif