from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, BatchNormalization, UpSampling2D
from keras.layers import LeakyReLU


def _conv_bn_act(x, filters, use_bias=False):
    """Apply one 3x3 'same' Conv2D -> BatchNormalization -> LeakyReLU(alpha=0.) unit."""
    x = Conv2D(filters, (3, 3), padding='same', use_bias=use_bias)(x)
    x = BatchNormalization()(x)
    # alpha=0. makes this LeakyReLU numerically identical to a plain ReLU.
    return LeakyReLU(alpha=0.)(x)


def _conv_stack(x, filters, depth, use_bias=False):
    """Apply `depth` consecutive conv/BN/activation units with `filters` channels each."""
    for _ in range(depth):
        x = _conv_bn_act(x, filters, use_bias=use_bias)
    return x


def u_net_small(inputs, num_classes=1):
    """Build a small 4-level U-Net feature extractor on top of `inputs`.

    Encoder: three stages (8, 16, 32 filters), each two conv/BN/activation
    units followed by a 2x2 stride-2 max-pool. Bottleneck: two units with
    64 filters. Decoder: three stages (32, 16, 8 filters). Each decoder stage
    upsamples by 2, concatenates the matching encoder output (skip
    connection), then applies three conv/BN/activation units.

    Args:
        inputs: Keras tensor to build on. Concatenation is done on axis 3,
            so this presumably expects a 4-D channels-last tensor
            (batch, H, W, C) -- confirm against callers. H and W should be
            divisible by 8: the three valid-padded stride-2 pools floor odd
            sizes, and the upsampled path would then not match the skip
            tensor it is concatenated with.
        num_classes: Currently unused; kept for interface compatibility.
            No classification head is applied.

    Returns:
        The final 8-channel decoder feature tensor at the input's spatial
        resolution.
    """
    # A conv bias would be cancelled by the BatchNormalization that follows.
    use_bias = False

    # Encoder stage 0: 8 filters, full resolution.
    down0 = _conv_stack(inputs, 8, 2, use_bias)
    down0_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0)

    # Encoder stage 1: 16 filters, 1/2 resolution.
    down1 = _conv_stack(down0_pool, 16, 2, use_bias)
    down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1)

    # Encoder stage 2: 32 filters, 1/4 resolution.
    down2 = _conv_stack(down1_pool, 32, 2, use_bias)
    down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2)

    # Bottleneck: 64 filters, 1/8 resolution.
    center = _conv_stack(down2_pool, 64, 2, use_bias)

    # Decoder stage 2: back to 1/4 resolution, fused with down2.
    up2 = UpSampling2D((2, 2))(center)
    up2 = concatenate([down2, up2], axis=3)
    up2 = _conv_stack(up2, 32, 3, use_bias)

    # Decoder stage 1: back to 1/2 resolution, fused with down1.
    up1 = UpSampling2D((2, 2))(up2)
    up1 = concatenate([down1, up1], axis=3)
    up1 = _conv_stack(up1, 16, 3, use_bias)

    # Decoder stage 0: back to full resolution, fused with down0.
    up0 = UpSampling2D((2, 2))(up1)
    up0 = concatenate([down0, up0], axis=3)
    up0 = _conv_stack(up0, 8, 3, use_bias)

    # No classification head here. A head was previously sketched (commented out) as:
    #   Conv2D(num_classes, (1, 1), activation='sigmoid')(up0)
    return up0