u_net.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, BatchNormalization, UpSampling2D
  2. from keras.layers import LeakyReLU
  def _conv_bn_relu(x, filters, use_bias):
      """Apply one 3x3 'same' Conv2D -> BatchNormalization -> LeakyReLU(alpha=0.) unit."""
      x = Conv2D(filters, (3, 3), padding='same', use_bias=use_bias)(x)
      x = BatchNormalization()(x)
      # alpha=0. makes LeakyReLU mathematically identical to a plain ReLU.
      # NOTE(review): Keras 3 renames this argument to `negative_slope`; `alpha`
      # is kept here so the layer config stays identical to the original.
      return LeakyReLU(alpha=0.)(x)


  def _conv_stack(x, filters, n_convs, use_bias):
      """Apply `n_convs` consecutive conv units, each with `filters` output channels."""
      for _ in range(n_convs):
          x = _conv_bn_relu(x, filters, use_bias)
      return x


  def u_net_small(inputs, num_classes=1, *, classify=False):
      """Build a small three-level U-Net on top of the Keras tensor `inputs`.

      Encoder: two conv units per level at 8, 16 and 32 filters, each level
      followed by a 2x2 / stride-2 max-pool (resolutions 1, 1/2, 1/4 of the
      input), then a 64-filter bottleneck of two conv units at 1/8 resolution.
      Decoder: at each level, upsample by 2, concatenate the matching encoder
      output (skip tensor first, upsampled tensor second) and apply three conv
      units at 32, 16 and 8 filters. Every conv unit is a 3x3 'same' Conv2D
      without bias -> BatchNormalization -> LeakyReLU(alpha=0.).

      Args:
          inputs: Keras tensor to build the graph on. Assumed 4-D channels_last
              (batch, height, width, channels), since skips are concatenated on
              axis 3 -- confirm against the caller. Height and width must be
              divisible by 8: max-pooling floors odd sizes, after which the
              upsampled decoder tensor no longer matches its skip tensor and
              `concatenate` fails.
          num_classes: Channel count of the optional 1x1 classification head.
              Only used when `classify` is True; kept in the signature for
              backward compatibility.
          classify: If True, append a 1x1 Conv2D with `num_classes` filters and
              sigmoid activation (independent per-channel probabilities). The
              default False returns the raw decoder feature map, as before.

      Returns:
          With `classify=False`, the 8-channel decoder feature map at the input's
          spatial resolution. With `classify=True`, the `num_classes`-channel
          sigmoid output at the same resolution.
      """
      # Conv bias is disabled, presumably because every conv is immediately
      # followed by BatchNormalization, whose centering makes a bias redundant.
      use_bias = False
      encoder_filters = (8, 16, 32)

      # Encoder: keep each level's pre-pool output for the skip connections.
      skips = []
      x = inputs
      for filters in encoder_filters:
          x = _conv_stack(x, filters, 2, use_bias)
          skips.append(x)
          x = MaxPooling2D((2, 2), strides=(2, 2))(x)

      # Bottleneck at 1/8 of the input resolution.
      x = _conv_stack(x, 64, 2, use_bias)

      # Decoder: walk the levels back up, deepest first, mirroring the encoder.
      for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
          x = UpSampling2D((2, 2))(x)
          x = concatenate([skip, x], axis=3)
          x = _conv_stack(x, filters, 3, use_bias)

      if classify:
          return Conv2D(num_classes, (1, 1), activation='sigmoid')(x)
      return x