• python执行SVM分类


      利用sklearn执行SVM分类时速度很慢,采用了多进程机制。

      一般多进程用于独立文件操作,各进程之间最好不通信。但此处,单幅影像SVM分类就很慢,只能添加多进程,由于不同进程之间不能共用一个变量(即使共用一个变量,还需要添加变量锁),故将单幅影像分为小幅,每小幅对应一个进程,每个进程对该小幅数据分类完成后,将处理结果输出到临时路径的临时文件中,最好再将临时文件按照顺序合成一个完整的分类结果。

    def sub_process(in_data, fill_value, classify_model, temp_dir, filename, y):
    """

    :param in_data:
    :param fill_value:
    :param classify_model:
    :param temp_dir:
    :param filename:
    :param y:
    :return:
    """
    try:
    nb_, nl_ ,ns_ = in_data.shape
    # fill_value = np.nan

    if np.isnan(fill_value):
    unvalid_index = np.where(in_data != in_data)
    else:
    unvalid_index = np.where(in_data == fill_value)
    nb_index = unvalid_index[0]
    nl_index = unvalid_index[1]
    ns_index = unvalid_index[2]
    for i,j in zip(nl_index, ns_index):
    in_data[:, i, j] = fill_value

    temp_data = in_data[0,:,:]
    if np.isnan(fill_value):
    valid_index = np.where(temp_data == temp_data)
    else:
    valid_index = np.where(temp_data != fill_value)

    # 获取有效特征数据
    valid_in_data = []
    for i in range(nb_):
    in_data_ = in_data[i, :, :]
    valid_in_data.append(in_data_[valid_index])
    del in_data_

    valid_in_data = np.array(valid_in_data)
    print(y, "Start predicting")
    prediction = classify_model.predict(np.transpose(valid_in_data))
    print(y, "Finish predicting")

    arr = np.zeros([nl_, ns_], dtype=np.byte)
    arr[valid_index] = prediction.astype("float").astype("int8")
    print(y, np.min(arr), np.max(arr))

    # out_band.WriteArray(arr, 0, y)
    # 多进程间不能共享被修改的变量(即使实现共享,还需要添加变量锁,降低效率)
    # class_arr[y:(y+nl_), :] = arr

    outfile = os.path.join(temp_dir, filename+"_"+str(y)+".tif")
    driver = gdal.GetDriverByName("GTiff")
    outds = driver.Create(outfile, ns_, nl_, 1, gdal.GDT_Byte, options=["COMPRESS=LZW"])
    outband = outds.GetRasterBand(1)
    outband.WriteArray(arr)
    del arr, outband, outds

    del prediction, valid_in_data, valid_index, unvalid_index, in_data

    except Exception as error_msg:
    print(str(error_msg))

    def classify():
      ......
      pools = Pool(self.num_process)
      for y in range(0, nl, block_ysize):
       if y + block_ysize < nl:
      rows = block_ysize
      else:
       rows = nl - y

      in_data = in_ds.ReadAsArray(0, y, block_xsize, rows)

      pools.apply_async(sub_process, args=(in_data, self.fill_value, classify_model, self.temp_dir, output_filename, y))
      pools.close()
      pools.join()

      # 整合各个进程的处理结果
      tempfiles = glob(os.path.join(self.temp_dir, output_filename+"*.tif"))
      for tempfile_ in tempfiles:
      temp_filename = os.path.splitext(os.path.basename(tempfile_))[0]
      strs = temp_filename.split("_")
      y = int(strs[-1])
      temp_ds = gdal.Open(tempfile_)
      temp_data = temp_ds.ReadAsArray()
      out_band.WriteArray(temp_data, 0, y)
      del temp_ds, temp_data

      del out_band, out_ds, in_ds
    
    
     
  • 相关阅读:
    转载:PHP JSON_ENCODE 不编码中文汉字的方法
    【TP3.2】:日志记录和查看
    PHP原生:分享一个轻量级的缓存类=>cache.php
    python: 基本的日期与时间转换
    python: 随机选择
    计算机bit是什么意思
    Python: 矩阵与线性代数运算
    Python numpy 安装以及处理报错 is not a supported wheel on this platform
    Python: 大型数组运算
    Python numpy有什么用?
  • 原文地址:https://www.cnblogs.com/jkmlscy/p/15505913.html
Copyright © 2020-2023  润新知