Python: 编写面向对象的代码

面向过程

程序应该是非常抽象的,就像“下载网页、计算频率、打印费波拉悉数列”一样易懂。

就拿随手遇到的项目来说。在数据预处理工作中,目标服务器以*.json.gz的形式存放数据,我们需要将数据下载下来。*以时间戳命名无序排列。处理时可以利用用这个值记录任务的完成进度。现在我们把将整个流程能翻译成Python程序:

def downloadDataFromURL(dest_data, local_dir):
    # function: use urllib download data file to local
    try:
        urllib.urlretrieve(dest_data, local_dir)
    except:
        print '\tError retrieving the URL:', dest_data
# directly running
if __name__ == '__main__':
    if os.path.isfile(tsFilePath):
        print 'read timestamp file: ' + tsFilePath
        try:
            tsFile = open(tsFilePath, 'r+')
            oldTimestamp = int(True and tsFile.read(10) or 0) # use and-or to pass empty file content
            print oldTimestamp
        except IOError as (errno, strerror):
            print "I/O error({0}): {1}".format(errno, strerror)
        except:
            print "Unexpected error:", sys.exc_info()[0]
            raise
    else:
        print 'No such file or directory: ' + tsFilePath
        tsFile = open(tsFilePath, 'w+')    
    # start 4 worker processes
    pool = Pool(POOLNUMBER)
    html = requests.get(serverUrl).text
    dirSoup = BeautifulSoup(html, 'html.parser')
    # get tag array
    aTags = dirSoup.find_all('a', href=re.compile("20"))
    for aTag in aTags:
        dataDirUrl =  aTag['href']
        html = requests.get(serverUrl + dataDirUrl).text
        secDirSoup = BeautifulSoup(html, 'html.parser')
        # filter gz file source
        srcTags= secDirSoup.find_all('a', href=re.compile("gz"))
        for srcTag in srcTags:
            # ! need to papare timestamp 
            dataFile = srcTag['href']
            # compare timestamp
            if int(dataFile[0:10]) > oldTimestamp:
                # run async function on Pool
                pool.apply_async(downloadDataFromURL, args=(serverUrl + dataDirUrl + dataFile, winLocalPath))
    # finish and store timestamp
    pool.close() # close after wait process finish
    pool.join()  # wait process finish
    # save new timestamp
    tsFile.seek(0, 0);
    tsFile.write(str(newTimestamp));
    print 'Scraping finish'

可以看到,除了把下载过程单独抽象出来,其他部分都是直接写下来。一开始只是当作脚本写,这种方法会导致流程中的各部分耦合度高,而耦合太高在开发中会增加修改成本和重复代码。但是这样子的程序是无法在别的模块复用的。当需要时,外部模块只能把这个功能模块当作一条命令调用。而这种方法唯一的优点也只是开始时的开发速度快,能够快速完成基本功能。

面向对象

上面的写法太丑,没有什么可读性,应该更加抽象化成对象。下面就用面向对象的风格重构代码。首先将任务流程重新理清楚。在服务器去获取下载文件的地址后,遍历这个地址表下载文件。在完成每个文件下载时,记录时间戳文件名以保证下次使用和错误恢复。

我们以FileDownloadTask来命名这个类,在Python中是使用__init__()方法来初始化对的。在这个类中我们需要初始化服务器的地址、存放文件的本地路径和存放时间戳的文件等信息。

def __init__(self, server_url, target_folder='./', last_run_time_file='last_run_time'):
    self.server_url = server_url
    self.target_folder = target_folder
    self.last_run_time_file = target_folder + last_run_time_file
    self.last_run_time = 0
    self.__load_last_run_time()
    self.lock = threading.Lock()

通过文件路径获得最后运行成功的时间戳,我们又可以把读取时间文件封装成一个函数。在这里我使用较安全的方式(try-except-finally)来打开和读取文件.

def __load_last_run_time(self):
    try:
        f = open(self.last_run_time_file)
    except:
        logging.info('Last run time file not found, set 0 as default')
        self.last_run_time = 0
        return self.last_run_time
    try:
        self.last_run_time = int(f.read())
    except:
        logging.info('Read last run time failed, set 0 as default')
        self.last_run_time = 0
    finally:
        f.close()
    return self.last_run_time

而下载任务我们通过调用.download_source_file()来实现。这里的多线程我们是可以控制的,默认为cpu的核数。我们只需重新下载上次成功下载之后的文件。

def download_source_file(self, thread_num=None):
    if thread_num is None:
        pool = ThreadPool()  # default thread_num is cpu number
    else:
        pool = ThreadPool(thread_num)  # default thread_num is cpu number

    source_urls = map(lambda i: self.server_url + i,
                      filter(lambda i: re.match(r'^\d\d\d\d\.\d\d\.\d\d-\d\d\.\d\d\.\d\d', i) and
                             self.__parse_timestamp_from_url(i) > self.last_run_time,
                             self.__all_href_from_remote(self.server_url)))

    def _download_one(source_url):
        # show later
    pool.map(_download_one, source_urls)
    pool.close()

这里面有几个类私用方法,.__parse_timestamp_from_url()是解析时间戳,.__all_href_from_remote()是获得文件下载地址,._download_one()是下载单个文件的方法。注意参数sourceurl是相对路径下的文件地址。通过方法第一步的map(),filenameurls存放的是(filename, fileurl),fileurl是绝对路径。在遍历这个元祖中,进行最多3次的文件下载。在文件下载时添加'.download'后缀名,只有成功下载才重命名回原来的文件名。这样的原因时为了防止下载到一般的文件导致下一步的处理出差错。

def _download_one(source_url):
    file_name_urls = map(lambda i: (i, source_url + i),
                            filter(lambda i: re.search(r'json\.gz$', i),
                                self.__all_href_from_remote(source_url)))
    for file_name, file_url in file_name_urls:
        logging.info('Downloading %s from %s...' % (file_name, file_url))
        # try download for most 3 times
        for x in range(3):
            try:
                urllib.urlretrieve(file_url, self.target_folder + file_name + '.download')
                tmp = self.__parse_timestamp_from_url(file_url)
                with self.lock:
                    self.last_run_time = max(self.last_run_time, tmp)
                    os.rename(self.target_folder + file_name + '.download', self.target_folder + file_name)
                    self.__write_last_run_time()
                return
            except Exception, e:
                continue
        logging.error('Download from %s failed: %s' % (file_url, format(e)))

其实这个部分是老大也参与到了修改,所以才能写得非常优美。在项目的后续工作中,这种思路应该一直伴随。只有开始的不怕麻烦,才能达到最终的懒惰。而这篇文章也当作我的Python入门吧。