利用sklearn执行SVM分类时速度很慢,采用了多进程机制。
一般多进程用于独立文件操作,各进程之间最好不通信。但此处,单幅影像SVM分类就很慢,只能添加多进程,由于不同进程之间不能共用一个变量(即使共用一个变量,还需要添加变量锁),故将单幅影像分为小幅,每小幅对应一个进程,每个进程对该小幅数据分类完成后,将处理结果输出到临时路径的临时文件中,最后再将临时文件按照顺序合成一个完整的分类结果。
def sub_process(in_data, fill_value, classify_model, temp_dir, filename, y):
    """Classify one image strip in a worker process and write it to a temp GTiff.

    Each worker handles an independent strip of the image so the slow SVM
    prediction can run in parallel; the parent merges the strips afterwards.

    :param in_data: 3-D array (bands, lines, samples) of feature data.
    :param fill_value: no-data marker; may be NaN.
    :param classify_model: fitted model exposing ``predict`` (e.g. sklearn SVM).
    :param temp_dir: directory receiving the per-strip temporary file.
    :param filename: base name for the temporary file.
    :param y: starting line of this strip in the full image; encoded in the
        temp-file name so strips can be merged back in order.
    :return: None; result is written to ``<temp_dir>/<filename>_<y>.tif``.
    """
    try:
        nb_, nl_, ns_ = in_data.shape
        # A pixel is invalid if ANY band holds the fill value (NaN != NaN,
        # hence the separate isnan branch).  Propagate the marker to every
        # band of such pixels.  Vectorized: the original per-pixel Python
        # loop was O(pixels) in the interpreter.
        if np.isnan(fill_value):
            invalid_mask = np.isnan(in_data).any(axis=0)
        else:
            invalid_mask = (in_data == fill_value).any(axis=0)
        in_data[:, invalid_mask] = fill_value
        valid_index = np.where(~invalid_mask)
        # (nb_, n_valid) feature matrix gathered in one fancy-index pass
        # instead of a band-by-band append loop.
        valid_in_data = in_data[:, valid_index[0], valid_index[1]]
        print(y, "Start predicting")
        prediction = classify_model.predict(np.transpose(valid_in_data))
        print(y, "Finish predicting")
        arr = np.zeros([nl_, ns_], dtype=np.byte)
        # Keep the float intermediate: labels may be float-like (or string
        # representations thereof), which a direct int8 cast could reject.
        arr[valid_index] = prediction.astype("float").astype("int8")
        print(y, np.min(arr), np.max(arr))
        # Worker processes cannot share a writable array without locking,
        # so each strip goes to its own temporary file; the parent merges
        # them by the "_<y>" suffix.
        outfile = os.path.join(temp_dir, filename + "_" + str(y) + ".tif")
        driver = gdal.GetDriverByName("GTiff")
        outds = driver.Create(outfile, ns_, nl_, 1, gdal.GDT_Byte,
                              options=["COMPRESS=LZW"])
        outband = outds.GetRasterBand(1)
        outband.WriteArray(arr)
        # Releasing the band/dataset objects flushes the file (GDAL idiom).
        del arr, outband, outds
        del prediction, valid_in_data, valid_index, in_data
    except Exception:
        # A bare print(str(e)) loses the traceback, which is the only
        # failure signal a worker process has — print it in full.
        import traceback
        traceback.print_exc()
def classify():
# NOTE(review): snippet — "......" elides the setup code (opening in_ds /
# out_ds, computing nl / block_xsize / block_ysize, loading classify_model,
# output_filename, etc.), and indentation was lost in this paste; the layout
# below mirrors the original verbatim.
......
# One worker process per strip of block_ysize image lines; self.num_process
# bounds the level of parallelism.
pools = Pool(self.num_process)
for y in range(0, nl, block_ysize):
if y + block_ysize < nl:
rows = block_ysize
else:
# The last strip may be shorter than block_ysize.
rows = nl - y
in_data = in_ds.ReadAsArray(0, y, block_xsize, rows)
# NOTE(review): apply_async discards the AsyncResult, so worker errors are
# only visible via sub_process's own exception handling.
pools.apply_async(sub_process, args=(in_data, self.fill_value, classify_model, self.temp_dir, output_filename, y))
pools.close()
pools.join()
# Merge the per-process strip results back into one classification raster.
tempfiles = glob(os.path.join(self.temp_dir, output_filename+"*.tif"))
for tempfile_ in tempfiles:
temp_filename = os.path.splitext(os.path.basename(tempfile_))[0]
strs = temp_filename.split("_")
# Each strip's starting line was encoded as the trailing "_<y>" of its
# temp-file name, so merge placement does not depend on glob order.
y = int(strs[-1])
temp_ds = gdal.Open(tempfile_)
temp_data = temp_ds.ReadAsArray()
out_band.WriteArray(temp_data, 0, y)
del temp_ds, temp_data
# Releasing the GDAL dataset objects flushes the output file.
del out_band, out_ds, in_ds