Libsvm中grid.py文件的解读

1.导入相关文件

这里重点讲一下 _ all _ = ['find_parameters']

_all__ = ['find_parameters'] 是 Python 中用于定义模块级别的变量 __all__ 的语法**all**是一个包含模块中应该被公开(即可以通过 from module import * 导入)的变量名的列表

  • __all__ 是一个约定俗成的变量名,用于指定在使用 from module import * 语句时,应该导入哪些变量名。这样可以控制模块的命名空间,避免不必要的变量污染。

  • ['find_parameters'] 是一个包含在 __all__ 中的列表,其中包含了模块中应该被导入的变量名。在这个例子中,只有一个变量名 find_parameters 被包含在 __all__ 中。

通过这个设置,当其他模块使用 from module import * 导入这个模块时,只有 find_parameters 这个变量名会被导入,其他未在 __all__ 中指定的变量不会被导入。这是一种良好的编程实践,因为它可以提供更清晰的模块接口,避免不必要的命名冲突和变量污染。

2.GridOption类的定义

构造函数接收两个参数:dataset_pathname 和 options

根据操作系统设置svm-train.exe和gnuplot.exe 的路径,这个要根据自己系统的实际按照情况 来进行路径的设置。

默认参数的设置以及解析传入参数的函数parse_options

最后,检查 SVM 训练可执行文件路径、数据集路径和 Gnuplot 可执行文件路径的存在性。

python 复制代码
class GridOption:
    '''
    构造函数 __init__:
    接收两个参数 dataset_pathname 和 options
                dataset_pathname 是数据集的路径
                options 是一个包含其他配置选项的字典
    获取当前脚本所在目录,并根据操作系统设置 svmtrain_pathname 和 gnuplot_pathname
    '''
    def __init__(self, dataset_pathname, options):
        dirname = os.path.dirname(__file__)
        # 使用 sys.platform 来检查操作系统
        # 如果不是 Windows (sys.platform != 'win32'),则设置 svmtrain_pathname 为在当前脚本所在目录下的 '.../svm-train',并设置 gnuplot_pathname 为 '/usr/bin/gnuplot'
        if sys.platform != 'win32':
            self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
            self.gnuplot_pathname = '/usr/bin/gnuplot'
        else:
            # example for windows
            # 如果是 Windows,则设置 svmtrain_pathname 为在当前脚本所在目录下的 r'...\windows\svm-train.exe',并设置 gnuplot_pathname 为 r'c:\tmp\gnuplot\binary\pgnuplot.exe'
            self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
            # svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe'
            self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
        # 默认参数的设置
        # 设置了一系列参数的默认值,例如 fold、c_begin、c_end、c_step、g_begin、g_end、g_step 等,用于定义网格搜索的参数范围和步长
        # 设置了 grid_with_c 和 grid_with_g 为 True,表示要在网格搜索中搜索 C 和 gamma 参数
        self.fold = 5
        self.c_begin, self.c_end, self.c_step = -5,  15,  2
        self.g_begin, self.g_end, self.g_step =  3, -15, -2
        self.grid_with_c, self.grid_with_g = True, True
        self.dataset_pathname = dataset_pathname  # 将传入的 dataset_pathname 赋值给 self.dataset_pathname
        self.dataset_title = os.path.split(dataset_pathname)[1]   # 提取数据集的标题部分,通过 os.path.split(dataset_pathname) 和 [1] 获取,赋值给 self.dataset_title
        self.out_pathname = '{0}.out'.format(self.dataset_title)  # 设置 out_pathname 为 '{0}.out',其中 {0} 是数据集标题
        self.png_pathname = '{0}.png'.format(self.dataset_title)  # 设置 png_pathname 为 '{0}.png',其中 {0} 是数据集标题
        self.pass_through_string = ' '  # 设置 pass_through_string 为一个空格
        self.resume_pathname = None     # 设置 resume_pathname 为 None
        self.parse_options(options)     # 调用 parse_options 方法,该方法用于解析传入的选项,并更新类的属性值

    # 定义了 parse_options 方法,该方法用于解析传入的选项列表,更新 GridOption 类的属性值
    def parse_options(self, options):
        # options 是传入的选项,可以是字符串,也可以是由字符串组成的列表
        # 如果 options 是字符串,通过 options.split() 将其分割成列表
        if type(options) == str:
            options = options.split()
        i = 0  # 初始化变量 i 为 0,用于迭代 options 列表
        # 初始化空列表 pass_through_options,用于存储未被解析的选项
        pass_through_options = []
        
        # 使用 while 循环遍历 options 列表
        # 通过检查当前选项,更新相应的 GridOption 类属性
        while i < len(options):
            '''
            -log2c 和 -log2g:解析参数范围和步长,如果值为 'null',则相应的网格搜索标志设为 False
            -v:设置交叉验证的折数
            -c 和 -g:抛出错误,提示使用 -log2c 和 -log2g
            -svmtrain:设置 SVM 训练可执行文件路径
            -gnuplot:设置 Gnuplot 可执行文件路径,如果值为 'null',则设为 None
            -out:设置输出文件路径,如果值为 'null',则设为 None
            -png:设置 PNG 文件路径
            -resume:设置恢复训练的文件路径,如果未提供则使用默认文件名
            '''
            if options[i] == '-log2c':
                i = i + 1
                if options[i] == 'null':
                    self.grid_with_c = False
                else:
                    self.c_begin, self.c_end, self.c_step = map(float,options[i].split(','))
            elif options[i] == '-log2g':
                i = i + 1
                if options[i] == 'null':
                    self.grid_with_g = False
                else:
                    self.g_begin, self.g_end, self.g_step = map(float,options[i].split(','))
            elif options[i] == '-v':
                i = i + 1
                self.fold = options[i]
            elif options[i] in ('-c','-g'):
                raise ValueError('Use -log2c and -log2g.')
            elif options[i] == '-svmtrain':
                i = i + 1
                self.svmtrain_pathname = options[i]
            elif options[i] == '-gnuplot':
                i = i + 1
                if options[i] == 'null':
                    self.gnuplot_pathname = None
                else:
                    self.gnuplot_pathname = options[i]
            elif options[i] == '-out':
                i = i + 1
                if options[i] == 'null':
                    self.out_pathname = None
                else:
                    self.out_pathname = options[i]
            elif options[i] == '-png':
                i = i + 1
                self.png_pathname = options[i]
            elif options[i] == '-resume':
                if i == (len(options)-1) or options[i+1].startswith('-'):
                    self.resume_pathname = self.dataset_title + '.out'
                else:
                    i = i + 1
                    self.resume_pathname = options[i]
            else:
                pass_through_options.append(options[i])  # 未识别的选项将被添加到 pass_through_options 列表中
            i = i + 1
        # 使用 ' '.join(pass_through_options) 将未识别的选项组合成一个字符串,更新 pass_through_string 属性
        self.pass_through_string = ' '.join(pass_through_options)

        # 检查 SVM 训练可执行文件路径、数据集路径和 Gnuplot 可执行文件路径的存在性
        if not os.path.exists(self.svmtrain_pathname):
            raise IOError('svm-train executable not found')
        if not os.path.exists(self.dataset_pathname):
            raise IOError('dataset not found')
        if self.resume_pathname and not os.path.exists(self.resume_pathname):
            raise IOError('file for resumption not found')  # 如果 resume_pathname 存在,检查其存在性
        if not self.grid_with_c and not self.grid_with_g:   # 如果同时设置了 -log2c 和 -log2g 为 False,抛出错误
            raise ValueError('-log2c and -log2g should not be null simultaneously')
        if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
        # 如果 Gnuplot 可执行文件不存在,输出错误信息并将其设为 None
            sys.stderr.write('gnuplot executable not found\n')
            self.gnuplot_pathname = None

补充"win32" 是 Windows 操作系统的平台标识符 。在 Python 中,sys.platform 返回一个字符串,表示当前运行 Python 解释器的平台。对于 Windows 系统,这个字符串通常是"win32"。所以,if sys.platform != 'win32' 这个条件语句检查当前操作系统是否为 Windows 之外的其他操作系统。

3. 定义redraw 函数,用于在图形界面中绘制 SVM 参数搜索的轮廓图

python 复制代码
def redraw(db,best_param,gnuplot,options,tofile=False):
    if len(db) == 0: return
    begin_level = round(max(x[2] for x in db)) - 3
    step_size = 0.5

    best_log2c,best_log2g,best_rate = best_param

    # if newly obtained c, g, or cv values are the same,
    # then stop redrawing the contour.
    if all(x[0] == db[0][0]  for x in db): return
    if all(x[1] == db[0][1]  for x in db): return
    if all(x[2] == db[0][2]  for x in db): return

    if tofile:
        gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n")
        gnuplot.write("set output \"{0}\"\n".format(options.png_pathname.replace('\\','\\\\')).encode())
        #gnuplot.write(b"set term postscript color solid\n")
        #gnuplot.write("set output \"{0}.ps\"\n".format(options.dataset_title).encode().encode())
    elif sys.platform == 'win32':
        gnuplot.write(b"set term windows\n")
    else:
        gnuplot.write( b"set term x11\n")
    gnuplot.write(b"set xlabel \"log2(C)\"\n")
    gnuplot.write(b"set ylabel \"log2(gamma)\"\n")
    gnuplot.write("set xrange [{0}:{1}]\n".format(options.c_begin,options.c_end).encode())
    gnuplot.write("set yrange [{0}:{1}]\n".format(options.g_begin,options.g_end).encode())
    gnuplot.write(b"set contour\n")
    gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode())
    gnuplot.write(b"unset surface\n")
    gnuplot.write(b"unset ztics\n")
    gnuplot.write(b"set view 0,0\n")
    gnuplot.write("set title \"{0}\"\n".format(options.dataset_title).encode())
    gnuplot.write(b"unset label\n")
    gnuplot.write("set label \"Best log2(C) = {0}  log2(gamma) = {1}  accuracy = {2}%\" \
                  at screen 0.5,0.85 center\n". \
                  format(best_log2c, best_log2g, best_rate).encode())
    gnuplot.write("set label \"C = {0}  gamma = {1}\""
                  " at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode())
    gnuplot.write(b"set key at screen 0.9,0.9\n")
    gnuplot.write(b"splot \"-\" with lines\n")

    db.sort(key = lambda x:(x[0], -x[1]))

    prevc = db[0][0]
    for line in db:
        if prevc != line[0]:
            gnuplot.write(b"\n")
            prevc = line[0]
        gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode())
    gnuplot.write(b"e\n")
    gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure
    gnuplot.flush()

4. 函数calculate_jobs 的定义

该函数接受一个参数 options,并返回两个值:jobsresumed_jobs,同时里面嵌套定义了函数 range_f 和函数 permute_sequence。

函数的主要目的是生成一系列的任务(jobs),每个任务是一个参数组合,用于训练支持向量机(SVM)。这些参数是通过对给定的一组参数范围进行排列组合得到的。

python 复制代码
def calculate_jobs(options):

    def range_f(begin,end,step):
        # like range, but works on non-integer too
        seq = []
        while True:
            if step > 0 and begin > end: break
            if step < 0 and begin < end: break
            seq.append(begin)
            begin = begin + step
        return seq

    def permute_sequence(seq):
        n = len(seq)
        if n <= 1: return seq

        mid = int(n/2)
        left = permute_sequence(seq[:mid])
        right = permute_sequence(seq[mid+1:])

        ret = [seq[mid]]
        while left or right:
            if left: ret.append(left.pop(0))
            if right: ret.append(right.pop(0))

        return ret


    c_seq = permute_sequence(range_f(options.c_begin,options.c_end,options.c_step))
    g_seq = permute_sequence(range_f(options.g_begin,options.g_end,options.g_step))

    if not options.grid_with_c:
        c_seq = [None]
    if not options.grid_with_g:
        g_seq = [None]

    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i, j = 0, 0
    jobs = []

    while i < nr_c or j < nr_g:
        if i/nr_c < j/nr_g:
            # increase C resolution
            line = []
            for k in range(0,j):
                line.append((c_seq[i],g_seq[k]))
            i = i + 1
            jobs.append(line)
        else:
            # increase g resolution
            line = []
            for k in range(0,i):
                line.append((c_seq[k],g_seq[j]))
            j = j + 1
            jobs.append(line)

    resumed_jobs = {}

    if options.resume_pathname is None:
        return jobs, resumed_jobs

    for line in open(options.resume_pathname, 'r'):
        line = line.strip()
        rst = re.findall(r'rate=([0-9.]+)',line)
        if not rst:
            continue
        rate = float(rst[0])

        c, g = None, None
        rst = re.findall(r'log2c=([0-9.-]+)',line)
        if rst:
            c = float(rst[0])
        rst = re.findall(r'log2g=([0-9.-]+)',line)
        if rst:
            g = float(rst[0])

        resumed_jobs[(c,g)] = rate

    return jobs, resumed_jobs

range_f函数:

  • range_f 函数是一个自定义的函数,类似于内置函数 range,但可以处理非整数的步长。它生成一个序列,从 begin 开始,以 step 为步长,直到不再满足条件。

permute_sequence函数:

  • permute_sequence 函数用于对给定序列进行排列组合。它采用分而治之的方法,将序列分成两半,然后递归地对左右两半进行排列组合,最终将结果合并。

参数生成:

  • 使用 range_f 函数生成了两个序列 c_seqg_seq,分别表示参数 cg 的可能取值。如果选项 options.grid_with_coptions.grid_with_g 为 False,则相应的参数序列为单一值,即 [None]

生成任务列表:

  • 使用生成的参数序列,通过两个循环(while 循环)生成所有可能的参数组合,存储在 jobs 列表中。

处理恢复任务:

  • 如果存在恢复路径 options.resume_pathname,则从该路径读取已经完成的任务信息,提取出参数组合和对应的性能率,并存储在 resumed_jobs 字典中。

返回结果:

  • 最终,函数返回两个值:生成的任务列表 jobs 和已经完成的任务信息字典 resumed_jobs

这段代码主要用于生成一系列参数组合,以及处理从先前运行中恢复的任务信息。这类功能通常在超参数搜索和模型训练中使用,以便系统能够自动尝试多种参数组合。

5.类WorkerStopToken的定义

通常用作信号或标记,用于通信或控制多线程或多进程的执行流程。在这里, WorkerStopToken 的目的是作为一个简单的标记,用于通知工作线程停止或表示工作线程已经停止。在实际应用中,它可能会与其他线程或进程之间的通信机制一起使用,以实现协同工作或关闭。

class WorkerStopToken: :定义了一个新的类,类名为 WorkerStopToken

pass :在Python中,pass 是一个占位符语句,不执行任何操作。在这里,它被用作类的主体部分,表示这个类是一个空类,没有任何成员或方法。

6. 类Worker的定义

Worker类继承自Python中的Thread类 ,这个类表示一个工作线程,用于执行支持向量机(SVM)的训练任务,该类定义了三个函数:_ init _方法、run方法、get_cmd方法

python 复制代码
class Worker(Thread):
    def __init__(self,name,job_queue,result_queue,options):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue
        self.options = options

    

__init__ 方法:

  • 初始化方法,接受四个参数:name(线程名称)、job_queue(任务队列)、result_queue(结果队列)、options(选项参数)
  • 将这些参数保存为实例变量(也可以说是成员变量),用于在线程运行时访问
  • self:表示对象的实例

python 复制代码
    def run(self):
        while True:
            (cexp,gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                self.job_queue.put((cexp,gexp))
                # print('worker {0} stop.'.format(self.name))
                break
            try:
                c, g = None, None
                if cexp != None:
                    c = 2.0**cexp
                if gexp != None:
                    g = 2.0**gexp
                rate = self.run_one(c,g)
                if rate is None: raise RuntimeError('get no rate')
            except:
                # we failed, let others do that and we just quit

                traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2])

                self.job_queue.put((cexp,gexp))
                sys.stderr.write('worker {0} quit.\n'.format(self.name))
                break
            else:
                self.result_queue.put((self.name,cexp,gexp,rate))

run 方法:

  • run 方法是 Thread 类的默认方法,在启动线程时会自动调用。这里是线程的主要执行逻辑。
  • 使用无限循环 (while True) 从任务队列 (job_queue) 获取任务,任务是 (cexp, gexp),其中 cexpgexp 表示对应的参数指数。
  • 如果接收到 WorkerStopToken,表示线程应该停止,将任务重新放回队列,并通过 break 退出循环,结束线程。
  • 否则,尝试将指数转换为实际的参数值 cg,然后调用 run_one 方法执行具体的 SVM 训练,并获取性能率。
  • 如果执行出错,将异常信息输出到标准错误流,并将任务重新放回队列,然后通过 sys.stderr.write 输出线程终止的信息,并通过 break 退出循环,结束线程。
  • 如果一切正常,将线程的名字、cexpgexp 和性能率放入结果队列 (result_queue)。

这段代码实现了一个工作线程的逻辑,用于执行 SVM 训练任务。它通过任务队列接收参数组合,执行训练,并将结果放入结果队列 。这样的多线程结构通常用于加速大规模参数搜索和训练任务


python 复制代码
    def get_cmd(self,c,g):
        options=self.options
        cmdline = '"' + options.svmtrain_pathname + '"'
        if options.grid_with_c:
            cmdline += ' -c {0} '.format(c)
        if options.grid_with_g:
            cmdline += ' -g {0} '.format(g)
        cmdline += ' -v {0} {1} {2} '.format\
            (options.fold,options.pass_through_string,options.dataset_pathname)
        return cmdline

get_cmd 方法:

  • 用于生成 SVM 训练的命令行字符串,其中包括 SVM 训练器的路径、参数 -c(如果启用)、参数 -g(如果启用)、参数 -v、折数、透传参数和数据集路径。

下面我再来详细地讲解一下get_cmd方法 :

def get_cmd(self,c,g)

定义了一个方法 get_cmd,接受两个参数 cg,表示 SVM 训练 的参数

options = self.options

将类实例中的 options 属性赋给局部变量 options,以便在后续代码中使用

cmdline = '"' + options.svmtrain_pathname + '"'

构建命令行字符串的开头部分,包含 SVM 训练器的路径。使用双引号将路径括起来,以防止 路径中包含空格时出现问题。

if options.grid_with_c:

检查选项 grid_with_c 是否为真,即是否启用了参数 c 的网格搜索

cmdline += ' -c {0} '.format(c)

如果启用了参数 c 的网格搜索,则将参数 c 的值添加到命令行字符串中

if options.grid_with_g:

检查选项 grid_with_g 是否为真,即是否启用了参数 g 的网格搜索

cmdline += ' -g {0} '.format(g)

如果启用了参数 g 的网格搜索,则将参数 g 的值添加到命令行字符串中

cmdline += ' -v {0} {1} {2} '.format(options.fold, options.pass_through_string, options.dataset_pathname)

添加 SVM 训练的其他参数,包括:

  • -v:表示要进行交叉验证
  • {0}:使用 options.fold 指定的折数
  • {1}:用户传递的额外参数
  • {2}:数据集的路径,由 options.dataset_pathname 指定

return cmdline

返回构建好的 SVM 训练命令行字符串

总体而言,这段代码的作用是根据给定的参数 cg 以及一些配置选项生成用于执行 SVM 训练的命令行字符串 。生成的命令行包括 SVM 训练器的路径、参数 -c(如果启用)、参数 -g(如果启用)、参数 -v、交叉验证的折数、额外参数和数据集的路径。

7.类LocalWorker的定义

定义了一个名为 LocalWorker 的类,它继承自先前提到的 Worker 类,并重写了 run_one 方法

python 复制代码
class LocalWorker(Worker):
    def run_one(self,c,g):
        cmdline = self.get_cmd(c,g)
        result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
        for line in result.readlines():
            if str(line).find('Cross') != -1:
                return float(line.split()[-1][0:-1])

run_one方法

该方法接受两个参数 cg,表示 SVM 训练的参数

cmdline = self.get_cmd(c,g)

调用父类 Workerget_cmd 方法,获取 SVM 训练的命令行字符串,并将其赋给 cmdline

result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout

使用 subprocess.Popen 创建一个新的进程,运行 SVM 训练的命令行,其中

  • cmdline 是要执行的命令行字符串
  • shell=True 表示使用系统的 shell 执行命令
  • stdout=PIPE 表示将命令的标准输出捕获到 result 变量中
  • stderr=PIPE 表示将命令的标准错误捕获,但在这段代码中没有使用
  • stdin=PIPE 表示标准输入连接到管道,但在这段代码中没有使用

for line in result.readlines():

遍历命令的标准输出的每一行

if str(line).find('Cross') != -1:

判断当前行是否包含字符串 'Cross'。如果包含,说明这一行包含了交叉验证的结果信息

return float(line.split()[-1][0:-1])

如果找到包含 'Cross' 的行,提取该行的最后一个单词,去掉末尾的换行符,并将其转换为浮点 数。这个值表示 SVM 训练的性能率。

总体而言,这段代码实现了在本地环境运行 SVM 训练任务的逻辑。它通过创建新的进程执行 SVM 训练命令行,并从命令的标准输出中提取包含交叉验证结果的行,最终返回性能率作为结果。

8.类SSHWorker的定义

python 复制代码
class SSHWorker(Worker):
    def __init__(self,name,job_queue,result_queue,host,options):
        Worker.__init__(self,name,job_queue,result_queue,options)
        self.host = host
        self.cwd = os.getcwd()
    def run_one(self,c,g):
        cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format\
            (self.host,self.cwd,self.get_cmd(c,g))
        result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
        for line in result.readlines():
            if str(line).find('Cross') != -1:
                return float(line.split()[-1][0:-1])

定义了一个名为 SSHWorker 的类,它同样继承自之前提到的 Worker 类,并进行了一些定制化。

该类定义了初始化函数和run_one函数

__init__方法

初始化方法,除了调用父类的初始化方法外,还接受一个额外的参数 host,表示远程主机的地 址。

self.host = host:将传入的 host 参数保存为实例变量,以便在后续代码中使用

self.cwd = os.getcwd():获取当前工作目录,并保存为实例变量 cwd

run_one方法

重写了 run_one 方法,该方法接受两个参数 cg,表示 SVM 训练的参数

cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"' .format (self.host, self.cwd, self.get_cmd(c,g))

构建了一个 SSH 命令行字符串,该命令行用于在远程主机上执行 SVM 训练任务

ssh -x -t -t:表示使用 SSH 连接,并在远程主机上执行命令

{0}:用传入的 host 替换占位符,表示远程主机的地址

"cd {1}; {2}":在远程主机上执行的命令,首先切换到当前工作目录(cwd),然后执行通过

调 用 get_cmd 方法生成的 SVM 训练命令

result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout:

使用 subprocess.Popen 创建一个新的进程,运行 SSH 命令行

  • cmdline 是要执行的 SSH 命令行字符串
  • stdout=PIPE 表示将命令的标准输出捕获到 result 变量中

for line in result.readlines():

遍历命令的标准输出的每一行

if str(line).find('Cross') != -1:

判断当前行是否包含字符串 'Cross'。如果包含,说明这一行包含了交叉验证的结果信息

return float(line.split()[-1][0:-1])

如果找到包含 'Cross' 的行,提取该行的最后一个单词,去掉末尾的换行符,并将其转换为浮点数。这个值表示在远程主机上运行 SVM 训练的性能率

总体而言,这段代码实现了在远程主机上通过 SSH 运行 SVM 训练任务的逻辑。它构建了相应的 SSH 命令行,执行远程任务,并从命令的标准输出中提取包含交叉验证结果的行,最终返回性能率作为结果。

9.类TelnetWorker的定义

python 复制代码
class TelnetWorker(Worker):
    def __init__(self,name,job_queue,result_queue,host,username,password,options):
        Worker.__init__(self,name,job_queue,result_queue,options)
        self.host = host
        self.username = username
        self.password = password
    def run(self):
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        tn.read_until('login: ')
        tn.write(self.username + '\n')
        tn.read_until('Password: ')
        tn.write(self.password + '\n')

        # XXX: how to know whether login is successful?
        tn.read_until(self.username)
        #
        print('login ok', self.host)
        tn.write('cd '+os.getcwd()+'\n')
        Worker.run(self)
        tn.write('exit\n')
    def run_one(self,c,g):
        cmdline = self.get_cmd(c,g)
        result = self.tn.write(cmdline+'\n')
        (idx,matchm,output) = self.tn.expect(['Cross.*\n'])
        for line in output.split('\n'):
            if str(line).find('Cross') != -1:
                return float(line.split()[-1][0:-1])

总体而言,这段代码实现了在远程主机上通过 Telnet 运行 SVM 训练任务的逻辑。它通过 Telnet 协议连接远程主机,执行相应的命令,并从输出中提取包含交叉验证结果的行,最终返回性能率作为结果。需要注意的是,代码中对登录成功的判断逻辑可能需要进一步完善。

10.函数find_parameters的定义

这段代码实现了对 SVM 模型参数的并行搜索和优化,通过多线程/进程执行不同参数组合的训练 任务,然后比较性能,最终找到最佳的参数组合。

用于参数搜索和优化的部分,具体来说,它使用了多线程/进程的方式来执行 SVM 参数的搜索工作

python 复制代码
def find_parameters(dataset_pathname, options=''):

    def update_param(c,g,rate,best_c,best_g,best_rate,worker,resumed):
        if (rate > best_rate) or (rate==best_rate and g==best_g and c<best_c):
            best_rate,best_c,best_g = rate,c,g
        stdout_str = '[{0}] {1} {2} (best '.format\
            (worker,' '.join(str(x) for x in [c,g] if x is not None),rate)
        output_str = ''
        if c != None:
            stdout_str += 'c={0}, '.format(2.0**best_c)
            output_str += 'log2c={0} '.format(c)
        if g != None:
            stdout_str += 'g={0}, '.format(2.0**best_g)
            output_str += 'log2g={0} '.format(g)
        stdout_str += 'rate={0})'.format(best_rate)
        print(stdout_str)
        if options.out_pathname and not resumed:
            output_str += 'rate={0}\n'.format(rate)
            result_file.write(output_str)
            result_file.flush()

        return best_c,best_g,best_rate

def find_parameters(dataset_pathname, options=''):

  • 定义了一个名为 find_parameters 的函数,用于寻找 SVM 模型的最佳参数

def update_param(c, g, rate, best_c, best_g, best_rate, worker, resumed):

  • 定义了一个辅助函数 update_param,用于更新最佳参数和最佳性能率

python 复制代码
options = GridOption(dataset_pathname, options);

    if options.gnuplot_pathname:
        gnuplot = Popen(options.gnuplot_pathname,stdin = PIPE,stdout=PIPE,stderr=PIPE).stdin
    else:
        gnuplot = None

options = GridOption(dataset_pathname, options);

  • 使用 GridOption 类处理参数选项,GridOption 类是对参数进行解析和处理的一个自定义类

if options.gnuplot_pathname:

  • 判断是否提供了 gnuplot 路径,如果提供了,则创建一个与 gnuplot 进程进行通信的管道

python 复制代码
  # put jobs in queue

    jobs,resumed_jobs = calculate_jobs(options)
    job_queue = Queue(0)
    result_queue = Queue(0)

    for (c,g) in resumed_jobs:
        result_queue.put(('resumed',c,g,resumed_jobs[(c,g)]))

    for line in jobs:
        for (c,g) in line:
            if (c,g) not in resumed_jobs:
                job_queue.put((c,g))

    # hack the queue to become a stack --
    # this is important when some thread
    # failed and re-put a job. It we still
    # use FIFO, the job will be put
    # into the end of the queue, and the graph
    # will only be updated in the end

    job_queue._put = job_queue.queue.appendleft

jobs, resumed_jobs = calculate_jobs(options)

调用 calculate_jobs 函数,生成需要执行的任务列表 jobs 和已经恢复的任务列表 resumed_jobs

job_queue = Queue(0)result_queue = Queue(0):

创建两个队列,job_queue 用于存放待执行的任务,result_queue 用于存放执行结果

for (c, g) in resumed_jobs:for line in jobs:

  • 循环遍历已经恢复的任务和待执行的任务

job_queue._put = job_queue.queue.appendleft

job_queue_put 方法指向 appendleft 方法,将队列变成一个栈,以确保重新放入的任务在队列头部


python 复制代码
 # fire telnet workers

    if telnet_workers:
        nr_telnet_worker = len(telnet_workers)
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            worker = TelnetWorker(host,job_queue,result_queue,
                     host,username,password,options)
            worker.start()

    # fire ssh workers

    if ssh_workers:
        for host in ssh_workers:
            worker = SSHWorker(host,job_queue,result_queue,host,options)
            worker.start()

    # fire local workers

    for i in range(nr_local_worker):
        worker = LocalWorker('local',job_queue,result_queue,options)
        worker.start()

    # gather results

    done_jobs = {}

    if options.out_pathname:
        if options.resume_pathname:
            result_file = open(options.out_pathname, 'a')
        else:
            result_file = open(options.out_pathname, 'w')

   

if telnet_workers:if ssh_workers:

根据是否提供了 Telnet 或 SSH 主机列表,启动相应的 TelnetWorker 或 SSHWorker

for i in range(nr_local_worker): 启动本地工作线程,数量由 nr_local_worker 决定

**done_jobs = {}:**用于存放已完成的任务及其结果

**if options.out_pathname:**如果提供了输出路径,则打开一个文件用于记录结果


python 复制代码
    db = []
    best_rate = -1
    best_c,best_g = None,None

    for (c,g) in resumed_jobs:
        rate = resumed_jobs[(c,g)]
        best_c,best_g,best_rate = update_param(c,g,rate,best_c,best_g,best_rate,'resumed',True)

    for line in jobs:
        for (c,g) in line:
            while (c,g) not in done_jobs:
                (worker,c1,g1,rate1) = result_queue.get()
                done_jobs[(c1,g1)] = rate1
                if (c1,g1) not in resumed_jobs:
                    best_c,best_g,best_rate = update_param(c1,g1,rate1,best_c,best_g,best_rate,worker,False)
            db.append((c,g,done_jobs[(c,g)]))
        if gnuplot and options.grid_with_c and options.grid_with_g:
            redraw(db,[best_c, best_g, best_rate],gnuplot,options)
            redraw(db,[best_c, best_g, best_rate],gnuplot,options,True)

**db = []best_rate = -1:**用于存放任务执行结果的数据库和记录最佳性能率的变量

for (c, g) in resumed_jobs: 遍历已恢复的任务,更新最佳参数和最佳性能率

for line in jobs: 遍历待执行的任务

while (c, g) not in done_jobs: 循环等待任务执行完成,并将执行结果放入 done_jobs

**(worker, c1, g1, rate1) = result_queue.get():**从结果队列中获取执行结果

**db.append((c, g, done_jobs[(c, g)])):**将任务执行结果加入数据库

if gnuplot and options.grid_with_c and options.grid_with_g:

如果提供了 gnuplot 路径,并且需要绘制图形,则调用 redraw 函数绘制图形


python 复制代码
    if options.out_pathname:
        result_file.close()
    job_queue.put((WorkerStopToken,None))
    best_param, best_cg  = {}, []
    if best_c != None:
        best_param['c'] = 2.0**best_c
        best_cg += [2.0**best_c]
    if best_g != None:
        best_param['g'] = 2.0**best_g
        best_cg += [2.0**best_g]
    print('{0} {1}'.format(' '.join(map(str,best_cg)), best_rate))

    return best_rate, best_param

if options.out_pathname:

  • 如果提供了输出路径,就关闭之前打开的文件result_file

job_queue.put((WorkerStopToken, None))

  • 向任务队列中放入停止信号,以停止工作线程

best_param, best_cg = {}, []print('{0} {1}'.format(' '.join(map(str, best_cg)), best_rate))

  • 输出最佳参数和最佳性能率

return best_rate, best_param 返回最佳性能率和最佳参数

11.程序入口函数的定义

这是一个命令行工具的入口,用于解析命令行参数并调用 find_parameters 函数进行参数搜索

python 复制代码
if __name__ == '__main__':

    def exit_with_help():
        print("""\
Usage: grid.py [grid_options] [svm_options] dataset

grid_options :
-log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
    begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
    "null"         -- do not grid with c
-log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
    begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
    "null"         -- do not grid with g
-v n : n-fold cross validation (default 5)
-svmtrain pathname : set svm executable path and name
-gnuplot {pathname | "null"} :
    pathname -- set gnuplot executable path and name
    "null"   -- do not plot
-out {pathname | "null"} : (default dataset.out)
    pathname -- set output file path and name
    "null"   -- do not output file
-png pathname : set graphic output file path and name (default dataset.png)
-resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
    This is experimental. Try this option only if some parameters have been checked for the SAME data.

svm_options : additional options for svm-train""")
        sys.exit(1)

   
    if len(sys.argv) < 2:
        exit_with_help()
    dataset_pathname = sys.argv[-1]
    options = sys.argv[1:-1]
    try:
        find_parameters(dataset_pathname, options)
    except (IOError,ValueError) as e:
        sys.stderr.write(str(e) + '\n')
        sys.stderr.write('Try "grid.py" for more information.\n')
        sys.exit(1)

if __name__ == '__main__':

  • 这是 Python 中的惯用写法,表示以下代码块将在作为脚本直接执行时运行

def exit_with_help():

  • 定义了一个辅助函数 exit_with_help,用于打印使用帮助信息并退出程序

print('' '' ''\ ...'' '' '')和 sys.exit(1)

  • 打印使用帮助信息,并使用 sys.exit(1) 退出程序

if len(sys.argv) < 2:exit_with_help()

如果命令行参数数量小于 2,则调用 exit_with_help 函数打印使用帮助信息并退出程序

dataset_pathname = sys.argv[-1]options = sys.argv[1:-1]:

  • 将命令行参数中的最后一个参数(数据集路径)赋值给 dataset_pathname,将除第一个参数和最后一个参数外的其他参数赋值给 options

try: ... except (IOError, ValueError) as e: ...

  • 使用 try...except 结构捕获可能发生的 IOErrorValueError 异常
  • try 块中调用 find_parameters 函数,传入数据集路径和其他参数
  • 如果捕获到异常,则将异常信息写入标准错误输出,打印提示信息,并退出程序

总体而言,这段代码实现了一个命令行工具的入口,用于解析命令行参数并调用 find_parameters 函数进行参数搜索。如果命令行参数不符合要求或者执行过程中出现异常,将打印使用帮助信息或错误信息,并退出程序。

相关推荐
数据小爬虫@2 小时前
深入解析:使用 Python 爬虫获取苏宁商品详情
开发语言·爬虫·python
健胃消食片片片片2 小时前
Python爬虫技术:高效数据收集与深度挖掘
开发语言·爬虫·python
ℳ₯㎕ddzོꦿ࿐5 小时前
解决Python 在 Flask 开发模式下定时任务启动两次的问题
开发语言·python·flask
CodeClimb5 小时前
【华为OD-E卷 - 第k个排列 100分(python、java、c++、js、c)】
java·javascript·c++·python·华为od
一水鉴天5 小时前
为AI聊天工具添加一个知识系统 之63 详细设计 之4:AI操作系统 之2 智能合约
开发语言·人工智能·python
Channing Lewis5 小时前
什么是 Flask 的蓝图(Blueprint)
后端·python·flask
B站计算机毕业设计超人5 小时前
计算机毕业设计hadoop+spark股票基金推荐系统 股票基金预测系统 股票基金可视化系统 股票基金数据分析 股票基金大数据 股票基金爬虫
大数据·hadoop·python·spark·课程设计·数据可视化·推荐算法
觅远6 小时前
python+playwright自动化测试(四):元素操作(键盘鼠标事件)、文件上传
python·自动化
ghostwritten7 小时前
Python FastAPI 实战应用指南
开发语言·python·fastapi
CM莫问7 小时前
python实战(十五)——中文手写体数字图像CNN分类
人工智能·python·深度学习·算法·cnn·图像分类·手写体识别