Kaynağa Gözat

First Commit

fangjiasheng 4 yıl önce
ebeveyn
işleme
b0a21e6371
100 değiştirilmiş dosya ile 17888 ekleme ve 7 silme
  1. 10 7
      .gitignore
  2. 0 0
      format_convert/UnRAR.exe
  3. 2459 0
      format_convert/convert.py
  4. 38 0
      format_convert/get_memory_info.py
  5. 6 0
      format_convert/judge_platform.py
  6. 117 0
      format_convert/libreoffice_interface.py
  7. 3 0
      format_convert/swf/__init__.py
  8. 188 0
      format_convert/swf/actions.py
  9. 344 0
      format_convert/swf/consts.py
  10. 1437 0
      format_convert/swf/data.py
  11. 1065 0
      format_convert/swf/export.py
  12. 229 0
      format_convert/swf/filters.py
  13. 371 0
      format_convert/swf/geom.py
  14. 171 0
      format_convert/swf/movie.py
  15. 81 0
      format_convert/swf/sound.py
  16. 499 0
      format_convert/swf/stream.py
  17. 2655 0
      format_convert/swf/tag.py
  18. 55 0
      format_convert/swf/utils.py
  19. 278 0
      format_convert/table_correct.py
  20. BIN
      format_convert/temp-0.5795441.jpg
  21. BIN
      format_convert/temp0.0.jpg
  22. BIN
      format_convert/temp0.jpg
  23. 0 0
      format_convert/temp1.1368683772161603e-13.jpg
  24. 0 0
      format_convert/temp107.52.jpg
  25. BIN
      format_convert/temp211.jpg
  26. BIN
      format_convert/temp232.61.jpg
  27. 0 0
      format_convert/temp31312.0.jpg
  28. BIN
      format_convert/temp316.63.jpg
  29. BIN
      format_convert/temp349.4635.jpg
  30. 0 0
      format_convert/temp350.0.jpg
  31. BIN
      format_convert/temp398.0.jpg
  32. 0 0
      format_convert/temp80.64.jpg
  33. BIN
      format_convert/temp90.0.jpg
  34. 411 0
      format_convert/test_ocr_interface.py
  35. 8 0
      format_convert/test_walk.py
  36. 14 0
      format_convert/testswf.py
  37. 198 0
      format_convert/timeout_decorator.py
  38. 0 0
      ocr/model/2.0/cls/inference.pdiparams
  39. 0 0
      ocr/model/2.0/cls/inference.pdiparams.info
  40. 0 0
      ocr/model/2.0/cls/inference.pdmodel
  41. 0 0
      ocr/model/2.0/det/inference.pdiparams
  42. 0 0
      ocr/model/2.0/det/inference.pdiparams.info
  43. 0 0
      ocr/model/2.0/det/inference.pdmodel
  44. 0 0
      ocr/model/2.0/rec/ch/inference.pdiparams
  45. 0 0
      ocr/model/2.0/rec/ch/inference.pdiparams.info
  46. 0 0
      ocr/model/2.0/rec/ch/inference.pdmodel
  47. 0 0
      ocr/model/2.0/rec/ch/origin_model/mobile/inference.pdiparams
  48. 0 0
      ocr/model/2.0/rec/ch/origin_model/mobile/inference.pdiparams.info
  49. 0 0
      ocr/model/2.0/rec/ch/origin_model/mobile/inference.pdmodel
  50. 0 0
      ocr/model/2.0/rec/ch/origin_model/server/inference.pdiparams
  51. 0 0
      ocr/model/2.0/rec/ch/origin_model/server/inference.pdiparams.info
  52. 0 0
      ocr/model/2.0/rec/ch/origin_model/server/inference.pdmodel
  53. 0 0
      ocr/model/2.0/rec/ch/production_model/mobile/inference.pdiparams
  54. 0 0
      ocr/model/2.0/rec/ch/production_model/mobile/inference.pdiparams.info
  55. 0 0
      ocr/model/2.0/rec/ch/production_model/mobile/inference.pdmodel
  56. 54 0
      ocr/my_infer.py
  57. 75 0
      ocr/my_infer_hub.py
  58. 155 0
      ocr/ocr_interface.py
  59. 362 0
      ocr/paddleocr.py
  60. 13 0
      ocr/ppocr/__init__.py
  61. 110 0
      ocr/ppocr/data/__init__.py
  62. 62 0
      ocr/ppocr/data/imaug/__init__.py
  63. 439 0
      ocr/ppocr/data/imaug/east_process.py
  64. 101 0
      ocr/ppocr/data/imaug/iaa_augment.py
  65. 281 0
      ocr/ppocr/data/imaug/label_ops.py
  66. 157 0
      ocr/ppocr/data/imaug/make_border_map.py
  67. 107 0
      ocr/ppocr/data/imaug/make_shrink_map.py
  68. 225 0
      ocr/ppocr/data/imaug/operators.py
  69. 140 0
      ocr/ppocr/data/imaug/randaugment.py
  70. 210 0
      ocr/ppocr/data/imaug/random_crop_data.py
  71. 435 0
      ocr/ppocr/data/imaug/rec_img_aug.py
  72. 774 0
      ocr/ppocr/data/imaug/sast_process.py
  73. 17 0
      ocr/ppocr/data/imaug/text_image_aug/__init__.py
  74. 116 0
      ocr/ppocr/data/imaug/text_image_aug/augment.py
  75. 164 0
      ocr/ppocr/data/imaug/text_image_aug/warp_mls.py
  76. 115 0
      ocr/ppocr/data/lmdb_dataset.py
  77. 126 0
      ocr/ppocr/data/simple_dataset.py
  78. 455 0
      ocr/ppocr/data/text2Image.py
  79. 42 0
      ocr/ppocr/losses/__init__.py
  80. 30 0
      ocr/ppocr/losses/cls_loss.py
  81. 205 0
      ocr/ppocr/losses/det_basic_loss.py
  82. 72 0
      ocr/ppocr/losses/det_db_loss.py
  83. 63 0
      ocr/ppocr/losses/det_east_loss.py
  84. 121 0
      ocr/ppocr/losses/det_sast_loss.py
  85. 39 0
      ocr/ppocr/losses/rec_att_loss.py
  86. 36 0
      ocr/ppocr/losses/rec_ctc_loss.py
  87. 47 0
      ocr/ppocr/losses/rec_srn_loss.py
  88. 37 0
      ocr/ppocr/metrics/__init__.py
  89. 45 0
      ocr/ppocr/metrics/cls_metric.py
  90. 72 0
      ocr/ppocr/metrics/det_metric.py
  91. 235 0
      ocr/ppocr/metrics/eval_det_iou.py
  92. 62 0
      ocr/ppocr/metrics/rec_metric.py
  93. 25 0
      ocr/ppocr/modeling/architectures/__init__.py
  94. 85 0
      ocr/ppocr/modeling/architectures/base_model.py
  95. 37 0
      ocr/ppocr/modeling/backbones/__init__.py
  96. 287 0
      ocr/ppocr/modeling/backbones/det_mobilenet_v3.py
  97. 280 0
      ocr/ppocr/modeling/backbones/det_resnet_vd.py
  98. 285 0
      ocr/ppocr/modeling/backbones/det_resnet_vd_sast.py
  99. 146 0
      ocr/ppocr/modeling/backbones/rec_mobilenet_v3.py
  100. 307 0
      ocr/ppocr/modeling/backbones/rec_resnet_fpn.py

+ 10 - 7
.gitignore

@@ -1,15 +1,17 @@
-/container.py
-/container.py
-/.idea/
-/convert_so/
-/libreoffice/
 /package_2021_12_28/
-/package_old/
 /usr/
+/ossUtils.py
+/fonts/
+/package_2021_12_28/
+/container.py
 /fonts.conf
 /libmergedlo.so
+/format_conversion_maxcompute.iml
+/package_old/
+/libreoffice/
+/.idea/
 /libuno_sal.so.3
-/ossUtils.py
+/convert_so/
 /so.env
 /stderr (4).txt
 /stream.py
@@ -17,3 +19,4 @@
 /test1.py
 /unrar
 /wiki_128_word_embedding_new.env
+/yep_homework.py

+ 0 - 0
format_convert/UnRAR.exe


+ 2459 - 0
format_convert/convert.py

@@ -0,0 +1,2459 @@
+#-*- coding: utf-8 -*-
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+import codecs
+import gc
+import hashlib
+import io
+import json
+import multiprocessing
+import sys
+import subprocess
+
+import PyPDF2
+import lxml
+import pdfminer
+from PIL import Image
+
+from format_convert import get_memory_info
+from ocr import ocr_interface
+from ocr.ocr_interface import ocr, OcrModels
+from otr import otr_interface
+from otr.otr_interface import otr, OtrModels
+import re
+import shutil
+import signal
+import sys
+import base64
+import time
+import traceback
+import uuid
+from os.path import basename
+import cv2
+import fitz
+import pandas
+import docx
+import zipfile
+import mimetypes
+import filetype
+# import pdfplumber
+import psutil
+import requests
+import rarfile
+from PyPDF2 import PdfFileReader, PdfFileWriter
+import xml.dom.minidom
+import subprocess
+import logging
+from pdfminer.pdfparser import PDFParser
+from pdfminer.pdfdocument import PDFDocument
+from pdfminer.pdfpage import PDFPage
+from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter
+from pdfminer.converter import PDFPageAggregator
+from pdfminer.layout import LTTextBoxHorizontal, LAParams, LTFigure, LTImage, LTCurve, LTText, LTChar
+import logging
+import chardet
+from bs4 import BeautifulSoup
+from format_convert.libreoffice_interface import office_convert
+from format_convert.swf.export import SVGExporter
+logging.getLogger("pdfminer").setLevel(logging.WARNING)
+from format_convert.table_correct import *
+from format_convert.swf.movie import SWF
+import logging
+# import timeout_decorator
+from format_convert import timeout_decorator
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+# txt doc docx xls xlsx pdf zip rar swf jpg jpeg png
+
+
def judge_error_code(_list, code=(-1, -2, -3, -4, -5, -7)):
    """Return True if *_list* is exactly a one-element error-code list.

    Converter functions in this module signal failure by returning a
    single-element list such as [-1]; this helper checks for that shape.

    :param _list: value returned by a converter (text list or error list).
    :param code: iterable of error codes to recognise (defaults to all
                 known codes; a tuple now, fixing the mutable-default
                 argument of the original).
    :return: True when _list equals [c] for some c in code, else False.
    """
    return any(_list == [c] for c in code)
+
+
def set_timeout(signum, frame):
    """Signal handler used with signal.alarm to abort long-running work.

    :param signum: signal number delivered by the OS (unused).
    :param frame: current stack frame at delivery time (unused).
    :raises TimeoutError: always, so the interrupted code unwinds.
    """
    # A single log line replaces the 16 identical debug prints of the
    # original implementation; the raised exception is unchanged.
    logging.info("=======================set_timeout")
    raise TimeoutError
+
+
def log_traceback(func_name):
    """Log *func_name* followed by the currently handled exception.

    Intended to be called from inside an ``except`` block: every line of
    the formatted traceback (including chained causes) is written via
    logging.info.
    """
    logging.info(func_name)
    _, exc_value, exc_tb = sys.exc_info()
    tb_exc = traceback.TracebackException(type(exc_value), exc_value, exc_tb, limit=None)
    for tb_line in tb_exc.format(chain=True):
        logging.info(tb_line)
+
+
def judge_format(path):
    """Guess the file format of *path* and map it to a short type name.

    First consults mimetypes (extension based); if that yields nothing,
    falls back to the content-based ``filetype`` sniffer.

    :param path: path of the file to classify.
    :return: one of "pdf", "docx", "zip", "rar", "xlsx", "doc", "png",
             "jpg" — or None when the MIME type is unknown or unmapped.
    """
    mime, _ = mimetypes.guess_type(path)
    if not mime:
        sniffed = filetype.guess(path)
        mime = sniffed.mime if sniffed else None

    mime_to_format = {
        "application/pdf": "pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
        "application/x-zip-compressed": "zip",
        "application/zip": "zip",
        "application/x-rar-compressed": "rar",
        "application/rar": "rar",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
        "application/msword": "doc",
        "image/png": "png",
        "image/jpeg": "jpg",
    }
    # Unmapped MIME types (and failed sniffs) yield None: "cannot guess".
    return mime_to_format.get(mime)
+
+
@get_memory_info.memory_decorator
def txt2text(path):
    """Read a plain-text file, auto-detecting its character encoding.

    :param path: path of the .txt file.
    :return: [text] on success, [-3] when the encoding cannot be
             determined or the file cannot be decoded with it, [-1] on
             any unexpected error.
    """
    logging.info("into txt2text")
    try:
        # Detect the encoding from the raw bytes.
        with open(path, "rb") as ff:
            data = ff.read()
        encode = chardet.detect(data).get("encoding")
        logging.info("txt2text judge code is " + str(encode))

        if encode is None:
            logging.info("txt2text cannot judge file code!")
            return [-3]
        try:
            with open(path, "r", encoding=encode) as ff:
                txt_text = ff.read()
            return [txt_text]
        except (LookupError, UnicodeError):
            # Narrowed from a bare except: only an unknown codec name or
            # a decode failure counts as "unreadable" (-3); anything
            # else is a real error handled by the outer handler.
            logging.info("txt2text cannot open file with code " + encode)
            return [-3]
    except Exception:
        logging.info("txt2text error!")
        print("txt2text", traceback.print_exc())
        return [-1]
+
+
@get_memory_info.memory_decorator
def doc2text(path, unique_type_dir):
    """Convert a legacy .doc file to text.

    The file is first converted to .docx through the LibreOffice
    interface, then handed to docx2text.  Error-code lists from either
    stage are propagated unchanged.

    :param path: path of the .doc file.
    :param unique_type_dir: working directory for the converted file.
    :return: [text] on success, or a single-element error-code list.
    """
    logging.info("into doc2text")
    try:
        converted_path = from_office_interface(path, unique_type_dir, 'docx')
        if judge_error_code(converted_path):
            # LibreOffice conversion failed; pass its error code through.
            return converted_path
        return docx2text(converted_path, unique_type_dir)
    except Exception:
        logging.info("doc2text error!")
        print("doc2text", traceback.print_exc())
        return [-1]
+
+
@get_memory_info.memory_decorator
def read_xml_order(path, save_path):
    """Read word/document.xml of a .docx and return the content order.

    Extracts document.xml from the docx zip archive into *save_path*,
    then walks <w:body> recording, per top-level node, which content
    kinds appear: "w:t" (paragraph text), "wp:docPr" (image) or
    "w:tbl" (table).  docx2text later uses this tag list to interleave
    paragraphs, OCR'd images and tables in original document order.

    :param path: path of the .docx file (a zip archive).
    :param save_path: directory to extract document.xml into.
    :return: list of tag names in document order, or [-3] for a broken
             docx, [-4] on XML-parse timeout, [-1] on other errors.
    """
    logging.info("into read_xml_order")
    try:
        try:
            f = zipfile.ZipFile(path)
            for file in f.namelist():
                if "word/document.xml" == str(file):
                    f.extract(file, save_path)
            f.close()
        except Exception as e:
            logging.info("docx format error!")
            return [-3]

        try:
            # xml_analyze is wrapped in a 300-second timeout decorator.
            collection = xml_analyze(save_path + "word/document.xml")
        except TimeoutError:
            logging.info("read_xml_order timeout")
            return [-4]

        body = collection.getElementsByTagName("w:body")[0]
        order_list = []
        for line in body.childNodes:
            if "w:p" in str(line):
                text = line.getElementsByTagName("w:t")
                picture = line.getElementsByTagName("wp:docPr")
                if text:
                    order_list.append("w:t")
                if picture:
                    order_list.append("wp:docPr")

                # Images embedded as <w:pict> inside a run also occupy
                # a picture slot in the order list.
                for line1 in line.childNodes:
                    if "w:r" in str(line1):
                        picture1 = line1.getElementsByTagName("w:pict")
                        if picture1:
                            order_list.append("wp:docPr")

            if "w:tbl" in str(line):
                order_list.append("w:tbl")
        # NOTE(review): the return value of read_xml_table is discarded
        # here and document.xml was already extracted above — presumably
        # this call is redundant; confirm before removing.
        read_xml_table(path, save_path)
        return order_list
    except Exception as e:
        logging.info("read_xml_order error!")
        print("read_xml_order", traceback.print_exc())
        return [-1]
+
+
@get_memory_info.memory_decorator
def read_xml_table(path, save_path):
    """Extract every <w:tbl> of a .docx as an HTML table string.

    document.xml is pulled out of the docx zip archive into *save_path*
    and parsed; each table is rendered as '<table border="1">...' with
    the colspan taken from <w:gridSpan>, and vertically merged cells
    (<w:vMerge w:val="continue">) filled with the text of the matching
    cell in the row above.

    :param path: path of the .docx file (a zip archive).
    :param save_path: directory to extract document.xml into.
    :return: list of HTML table strings, or [-3] for a broken docx,
             [-4] on XML-parse timeout, [-1] on other errors.
    """
    logging.info("into read_xml_table")
    try:
        try:
            f = zipfile.ZipFile(path)
            for file in f.namelist():
                if "word/document.xml" == str(file):
                    f.extract(file, save_path)
            f.close()
        except Exception as e:
            logging.info("docx format error!")
            return [-3]

        try:
            # xml_analyze is wrapped in a 300-second timeout decorator.
            collection = xml_analyze(save_path + "word/document.xml")
        except TimeoutError:
            logging.info("read_xml_table timeout")
            return [-4]

        body = collection.getElementsByTagName("w:body")[0]
        table_text_list = []
        for line in body.childNodes:
            if "w:tbl" in str(line):
                table_text = '<table border="1">' + "\n"
                tr_list = line.getElementsByTagName("w:tr")
                tr_index = 0
                tr_text_list = []
                tr_text_list_colspan = []  # NOTE(review): unused.
                for tr in tr_list:
                    table_text = table_text + "<tr rowspan=1>" + "\n"
                    tc_list = tr.getElementsByTagName("w:tc")
                    tc_index = 0
                    tc_text_list = []
                    for tc in tc_list:
                        tc_text = ""

                        # How many grid columns this cell spans.
                        col_span = tc.getElementsByTagName("w:gridSpan")
                        if col_span:
                            col_span = int(col_span[0].getAttribute("w:val"))
                        else:
                            col_span = 1

                        # Continuation of a vertically merged cell: copy
                        # the text of the corresponding cell from the
                        # row above, located by walking the colspans.
                        is_merge = tc.getElementsByTagName("w:vMerge")
                        if is_merge:
                            is_merge = is_merge[0].getAttribute("w:val")
                            if is_merge == "continue":
                                col_span_index = 0
                                real_tc_index = 0

                                if 0 <= tr_index - 1 < len(tr_text_list):
                                    for tc_colspan in tr_text_list[tr_index - 1]:
                                        if col_span_index < tc_index:
                                            col_span_index += tc_colspan[1]
                                            real_tc_index += 1

                                    if real_tc_index < len(tr_text_list[tr_index - 1]):
                                        tc_text = tr_text_list[tr_index - 1][real_tc_index][0]

                        table_text = table_text + "<td colspan=" + str(col_span) + ">" + "\n"
                        p_list = tc.getElementsByTagName("w:p")

                        # Concatenate all text runs of the cell, one
                        # newline per paragraph that contained text.
                        for p in p_list:
                            t = p.getElementsByTagName("w:t")
                            if t:
                                for tt in t:
                                    if len(tt.childNodes) > 0:
                                        tc_text += tt.childNodes[0].nodeValue
                                tc_text += "\n"

                        table_text = table_text + tc_text + "</td>" + "\n"
                        tc_index += 1
                        tc_text_list.append([tc_text, col_span])
                    table_text += "</tr>" + "\n"
                    tr_index += 1
                    tr_text_list.append(tc_text_list)
                table_text += "</table>" + "\n"
                table_text_list.append(table_text)
        return table_text_list

    except Exception as e:
        logging.info("read_xml_table error")
        print("read_xml_table", traceback.print_exc())
        return [-1]
+
+
@get_memory_info.memory_decorator
@timeout_decorator.timeout(300, timeout_exception=TimeoutError)
def xml_analyze(path):
    """Parse the XML file at *path* and return its document element.

    Wrapped in a 300-second timeout so that huge document.xml files
    raise TimeoutError instead of blocking the conversion worker.
    """
    dom_tree = xml.dom.minidom.parse(path)
    return dom_tree.documentElement
+
+
def read_docx_table(document):
    """Render every table of a python-docx Document as simple HTML.

    :param document: a docx.Document instance (anything exposing
                     ``.tables`` with ``.rows``/``.cells``/``.text``
                     works the same way).
    :return: list with one "<table>...</table>" string per table.
    """
    # Removed the leftover debug prints of the original version; the
    # returned strings are unchanged.
    table_text_list = []
    for table in document.tables:
        table_text = "<table>\n"
        for row in table.rows:
            table_text += "<tr>\n"
            for cell in row.cells:
                table_text += "<td>" + cell.text + "</td>\n"
            table_text += "</tr>\n"
        table_text += "</table>\n"
        table_text_list.append(table_text)
    return table_text_list
+
+
@get_memory_info.memory_decorator
def docx2text(path, unique_type_dir):
    """Convert a .docx file to text/HTML.

    Paragraph text, OCR'd embedded images and tables are extracted
    separately, then re-interleaved in original document order using
    the tag sequence produced by read_xml_order.

    :param path: path of the .docx file.
    :param unique_type_dir: working directory for temporary files.
    :return: [text] on success, or a single-element error-code list
             ([-1]/[-2]/[-3]/[-4]) propagated from the sub-steps.
    """
    logging.info("into docx2text")
    try:
        try:
            doc = docx.Document(path)
        except Exception as e:
            print("docx format error!", e)
            print(traceback.print_exc())
            logging.info("docx format error!")
            return [-3]

        # Collect paragraph text, each paragraph wrapped in a <div>.
        paragraph_text_list = []
        for paragraph in doc.paragraphs:
            if paragraph.text != "":
                paragraph_text_list.append("<div>" + paragraph.text + "</div>" + "\n")

        # Collect tables (parsed straight from document.xml).
        try:
            table_text_list = read_xml_table(path, unique_type_dir)
        except TimeoutError:
            return [-4]

        if judge_error_code(table_text_list):
            return table_text_list

        # Collect embedded images in document order and OCR each one.
        # Empty-text runs that reference a relationship id (rIdNN) of
        # an image part are treated as inline images.
        image_text_list = []
        temp_image_path = unique_type_dir + "temp_image.png"
        pattern = re.compile('rId\d+')
        for graph in doc.paragraphs:
            for run in graph.runs:
                if run.text == '':
                    try:
                        if not pattern.search(run.element.xml):
                            continue
                        content_id = pattern.search(run.element.xml).group(0)
                        content_type = doc.part.related_parts[content_id].content_type
                    except Exception as e:
                        print("docx no image!", e)
                        continue
                    if not content_type.startswith('image'):
                        continue

                    # Dump the image blob to a temp file for the OCR step.
                    img_data = doc.part.related_parts[content_id].blob
                    with open(temp_image_path, 'wb') as f:
                        f.write(img_data)

                    if img_data is None:
                        continue

                    # OCR the image; -2/-1 abort the whole conversion,
                    # -3 (unreadable image) just skips this image.
                    image_text = picture2text(temp_image_path)
                    if image_text == [-2]:
                        return [-2]
                    if image_text == [-1]:
                        return [-1]
                    if image_text == [-3]:
                        continue

                    image_text = image_text[0]
                    image_text_list.append(add_div(image_text))

        # Parse document.xml to recover the original content order.
        order_list = read_xml_order(path, unique_type_dir)
        if order_list == [-2]:
            return [-2]
        if order_list == [-1]:
            return [-1]

        text = ""
        print("len(order_list)", len(order_list))
        print("len(paragraph_text_list)", len(paragraph_text_list))
        print("len(image_text_list)", len(image_text_list))
        print("len(table_text_list)", len(table_text_list))

        # Emit paragraphs / images / tables in the recorded order; each
        # queue is consumed from the front as its tag appears.
        for tag in order_list:
            if tag == "w:t":
                if len(paragraph_text_list) > 0:
                    text += paragraph_text_list.pop(0)
            if tag == "wp:docPr":
                if len(image_text_list) > 0:
                    text += image_text_list.pop(0)
            if tag == "w:tbl":
                if len(table_text_list) > 0:
                    text += table_text_list.pop(0)
        return [text]
    except Exception as e:
        logging.info("docx2text error!")
        print("docx2text", traceback.print_exc())
        return [-1]
+
+
def add_div(text):
    """Wrap each line of *text* in <div>...</div> markers.

    Text that is empty, None, or already contains a <div> tag is
    returned unchanged.  Every newline closes the current div and opens
    the next one; the dangling "<div>" produced by a trailing newline
    is trimmed.

    :param text: plain text, possibly multi-line.
    :return: div-wrapped text (or the input untouched).
    """
    if text == "" or text is None:
        return text
    # Already formatted — don't wrap twice.  (The platform-gated debug
    # prints of the original version were removed.)
    if "<div>" in text:
        return text

    text = "<div>" + text + "\n"
    text = text.replace("\n", "</div>\n<div>")
    # The appended newline always leaves an empty trailing "<div>".
    if text.endswith("<div>"):
        text = text[:-5]
    return text
+
+
@get_memory_info.memory_decorator
def pdf2Image(path, save_dir):
    """Render each page of a PDF to a PNG image via PyMuPDF (fitz).

    Pages are rendered at 1.33333333x zoom and written to
    "<save_dir>_page<no>.png".  Only the first 70 pages are converted.

    :param path: path of the PDF file.
    :param save_dir: filename prefix for the page images.
    :return: list of written image paths, or [-3] for a broken PDF,
             [-7] when the document is password-protected, [-1] on
             unexpected errors.
    """
    logging.info("into pdf2Image")
    try:
        try:
            doc = fitz.open(path)
        except Exception as e:
            logging.info("pdf format error!")
            return [-3]

        output_image_list = []
        for page_no in range(doc.page_count):
            # Cap the page count: only the first 70 pages are rendered.
            if page_no >= 70:
                logging.info("pdf2Image: pdf pages count " + str(doc.page_count)
                             + ", only get 70 pages")
                break

            try:
                page = doc.loadPage(page_no)
                output = save_dir + "_page" + str(page_no) + ".png"
                rotate = int(0)
                # A zoom factor of 1.33333333 per axis raises the
                # rendered resolution above the 792x612 @ dpi=96
                # default (1.33333333 -> 1056x816, 2 -> 1584x1224).
                zoom_x = 1.33333333
                zoom_y = 1.33333333
                mat = fitz.Matrix(zoom_x, zoom_y).preRotate(rotate)
                pix = page.getPixmap(matrix=mat, alpha=False)
                pix.writePNG(output)
                output_image_list.append(output)
            except ValueError as e:
                traceback.print_exc()
                # Missing pages are skipped; encrypted documents abort
                # with the password error code.
                if str(e) == "page not in document":
                    logging.info("pdf2Image page not in document! continue..." + str(page_no))
                    continue
                elif "encrypted" in str(e):
                    logging.info("pdf2Image document need password " + str(page_no))
                    return [-7]
            except RuntimeError as e:
                if "cannot find page" in str(e):
                    logging.info("pdf2Image page {} not in document! continue... ".format(str(page_no)) + str(e))
                    continue
                else:
                    traceback.print_exc()
                    return [-3]
        return output_image_list

    except Exception as e:
        logging.info("pdf2Image error!")
        print("pdf2Image", traceback.print_exc())
        return [-1]
+
+
def image_preprocess(image_np, image_path, use_ocr=True):
    """Deskew an image, detect tables (OTR) and extract text (OCR).

    The image at *image_path* is rotation-corrected in place, resized
    to the nearest model input size and passed to the OTR interface for
    table-cell detection.  With two or more detected cell boxes the
    page is treated as a table and formatted via get_formatted_table;
    otherwise plain OCR is run (only when *use_ocr* is True).

    :param image_np: the image as a cv2/numpy array.
    :param image_path: path the corrected image is written back to.
    :param use_ocr: when False and no table is found, skip OCR and
                    return None as the text.
    :return: (text, column_list, outline_points, is_table); on failure
             text is a single-element error-code list and the rest are
             empty / 0.
    """
    logging.info("into image_preprocess")
    try:
        # Deskew/rotate the image, overwriting image_path.
        g_r_i = get_rotated_image(image_np, image_path)
        if g_r_i == [-1]:
            return [-1], [], [], 0

        # OTR wants a model-friendly size; write the resized copy to a
        # separate "_resize" path next to the original.
        image_np = cv2.imread(image_path)
        best_h, best_w = get_best_predict_size(image_np)
        image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
        image_resize_path = image_path[:-4] + "_resize" + image_path[-4:]
        cv2.imwrite(image_resize_path, image_resize)

        # Table-line detection (OTR) on the resized image.
        with open(image_resize_path, "rb") as f:
            image_bytes = f.read()
        points, split_lines, bboxes, outline_points = from_otr_interface(image_bytes)
        if judge_error_code(points):
            return points, [], [], 0

        # Map the detected bboxes from resized coordinates back to the
        # original image scale.
        ratio = (image_np.shape[0]/best_h, image_np.shape[1]/best_w)
        for i in range(len(bboxes)):
            bbox = bboxes[i]
            bboxes[i] = [(int(bbox[0][0]*ratio[1]), int(bbox[0][1]*ratio[0])),
                         (int(bbox[1][0]*ratio[1]), int(bbox[1][1]*ratio[0]))]

        # OCR runs on the full-size corrected image.
        with open(image_path, "rb") as f:
            image_bytes = f.read()
        # Two or more cell boxes => treat the page as containing a table.
        if len(bboxes) >= 2:
            text_list, bbox_list = from_ocr_interface(image_bytes, True)
            if judge_error_code(text_list):
                return text_list, [], [], 0

            text, column_list = get_formatted_table(text_list, bbox_list, bboxes, split_lines)
            if judge_error_code(text):
                return text, [], [], 0
            is_table = 1
            return text, column_list, outline_points, is_table

        # No table: plain OCR, or nothing at all when use_ocr is False.
        else:
            if use_ocr:
                text = from_ocr_interface(image_bytes)
                if judge_error_code(text):
                    return text, [], [], 0

                is_table = 0
                return text, [], [], is_table
            else:
                is_table = 0
                return None, [], [], is_table

    except Exception as e:
        logging.info("image_preprocess error")
        print("image_preprocess", traceback.print_exc())
        return [-1], [], [], 0
+
+
def get_best_predict_size(image_np):
    """Pick the model input size closest to the image's own size.

    Height and width are snapped independently to the nearest value in
    a fixed list of sizes accepted by the table-detection model.

    :param image_np: image array; only .shape[0] (height) and
                     .shape[1] (width) are read.
    :return: (best_height, best_width) tuple.
    """
    sizes = [1280, 1152, 1024, 896, 768, 640, 512, 384, 256, 128]
    # min() keeps the first candidate on a tie, matching the original
    # manual scan (which only replaced on strictly smaller distance).
    best_height = min(sizes, key=lambda s: abs(image_np.shape[0] - s))
    best_width = min(sizes, key=lambda s: abs(image_np.shape[1] - s))
    return best_height, best_width
+
+
+@get_memory_info.memory_decorator
+def pdf2text(path, unique_type_dir):
+    logging.info("into pdf2text")
+    try:
+        # pymupdf pdf to image
+        save_dir = path.split(".")[-2] + "_" + path.split(".")[-1]
+        output_image_list = pdf2Image(path, save_dir)
+        if judge_error_code(output_image_list):
+            return output_image_list
+
+        # 获取每页pdf提取的文字、表格的列数、轮廓点、是否含表格、页码
+        page_info_list = []
+        page_no = 0
+        for img_path in output_image_list:
+            print("pdf page", page_no, "in total", len(output_image_list))
+            # 读不出来的跳过
+            try:
+                img = cv2.imread(img_path)
+                img_size = img.shape
+            except:
+                logging.info("pdf2text read image in page fail! continue...")
+                continue
+            # print("pdf2text img_size", img_size)
+
+            text, column_list, outline_points, is_table = image_preprocess(img, img_path,
+                                                                           use_ocr=False)
+            if judge_error_code(text):
+                return text
+
+            page_info_list.append([text, column_list, outline_points, is_table,
+                                   page_no, img_size])
+            page_no += 1
+        # print("pdf2text", page_info_list)
+
+        # 包含table的和不包含table的
+        has_table_list = []
+        has_table_page_no_list = []
+        no_table_list = []
+        no_table_page_no_list = []
+        for page_info in page_info_list:
+            # 含表格不含表格分开
+            if not page_info[3]:
+                no_table_list.append(page_info)
+                no_table_page_no_list.append(page_info[4])
+            else:
+                has_table_list.append(page_info)
+                has_table_page_no_list.append(page_info[4])
+
+        # 页码表格连接
+        table_connect_list, connect_text_list = page_table_connect(has_table_list,
+                                                                   page_info_list)
+        # table_connect_list, connect_text_list = [], []
+        if judge_error_code(table_connect_list):
+            return table_connect_list
+
+        # 连接的页码
+        table_connect_page_no_list = []
+        for area in connect_text_list:
+            table_connect_page_no_list.append(area[1])
+        # print("pdf2text table_connect_list", table_connect_list)
+
+        # pdfminer 方式
+        try:
+            fp = open(path, 'rb')
+            # 用文件对象创建一个PDF文档分析器
+            parser = PDFParser(fp)
+            # 创建一个PDF文档
+            doc = PDFDocument(parser)
+            # 连接分析器,与文档对象
+            rsrcmgr = PDFResourceManager()
+            device = PDFPageAggregator(rsrcmgr, laparams=LAParams())
+            interpreter = PDFPageInterpreter(rsrcmgr, device)
+
+            # 判断是否能读pdf
+            for page in PDFPage.create_pages(doc):
+                break
+        except pdfminer.psparser.PSEOF as e:
+            # pdfminer 读不了空白页的对象,直接使用pymupdf转换出的图片进行ocr识别
+            logging.info("pdf2text " + str(e) + " use ocr read pdf!")
+            text_list = []
+            for page_info in page_info_list:
+                page_no = page_info[4]
+                # 表格
+                if page_info[3]:
+                    # 判断表格是否跨页连接
+                    area_no = 0
+                    jump_page = 0
+                    for area in table_connect_list:
+                        if page_no in area:
+                            # 只记录一次text
+                            if page_no == area[0]:
+                                image_text = connect_text_list[area_no][0]
+                                text_list.append([image_text, page_no, 0])
+                            jump_page = 1
+                        area_no += 1
+
+                    # 是连接页的跳过后面步骤
+                    if jump_page:
+                        continue
+
+                    # 直接取text
+                    image_text = page_info_list[page_no][0]
+                    text_list.append([image_text, page_no, 0])
+                # 非表格
+                else:
+                    with open(output_image_list[page_no], "rb") as ff:
+                        image_stream = ff.read()
+                    image_text = from_ocr_interface(image_stream)
+                    text_list.append([image_text, page_no, 0])
+
+            text_list.sort(key=lambda z: z[1])
+            text = ""
+            for t in text_list:
+                text += t[0]
+            return [text]
+        except Exception as e:
+            logging.info("pdf format error!")
+            traceback.print_exc()
+            return [-3]
+
+        text_list = []
+        page_no = 0
+
+        pages = PDFPage.create_pages(doc)
+        for page in pages:
+            logging.info("pdf2text page_no " + str(page_no))
+            # 限制pdf页数,只取前100页
+            if page_no >= 70:
+                logging.info("pdf2text: pdf pages only get 100 pages")
+                break
+
+            # 判断页码在含表格页码中,直接拿已生成的text
+            if page_no in has_table_page_no_list:
+                # 判断表格是否跨页连接
+                area_no = 0
+                jump_page = 0
+                for area in table_connect_list:
+                    if page_no in area:
+                        # 只记录一次text
+                        if page_no == area[0]:
+                            image_text = connect_text_list[area_no][0]
+                            text_list.append([image_text, page_no, 0])
+                        jump_page = 1
+                    area_no += 1
+
+                # 是连接页的跳过后面步骤
+                if jump_page:
+                    page_no += 1
+                    continue
+
+                # 直接取text
+                image_text = page_info_list[page_no][0]
+                text_list.append([image_text, page_no, 0])
+                page_no += 1
+                continue
+
+            # 不含表格的解析pdf
+            else:
+                if get_platform() == "Windows":
+                    try:
+                        interpreter.process_page(page)
+                        layout = device.get_result()
+                    except Exception:
+                        logging.info("pdf2text pdfminer read pdf page error! continue...")
+                        continue
+
+                else:
+                    # 设置超时时间
+                    try:
+                        # 解析pdf中的不含表格的页
+                        if get_platform() == "Windows":
+                            origin_pdf_analyze = pdf_analyze.__wrapped__
+                            layout = origin_pdf_analyze(interpreter, page, device)
+                        else:
+                            layout = pdf_analyze(interpreter, page, device)
+                    except TimeoutError as e:
+                        logging.info("pdf2text pdfminer read pdf page time out!")
+                        return [-4]
+                    except Exception:
+                        logging.info("pdf2text pdfminer read pdf page error! continue...")
+                        continue
+
+                # 判断该页有没有文字对象,没有则有可能是有水印
+                only_image = 1
+                image_count = 0
+                for x in layout:
+                    if isinstance(x, LTTextBoxHorizontal):
+                        only_image = 0
+                    if isinstance(x, LTFigure):
+                        image_count += 1
+
+                # 如果该页图片数量过多,直接ocr整页识别
+                logging.info("pdf2text image_count" + str(image_count))
+                if image_count >= 3:
+                    with open(output_image_list[page_no], "rb") as ff:
+                        image_stream = ff.read()
+                    image_text = from_ocr_interface(image_stream)
+
+                    if judge_error_code(image_text):
+                        return image_text
+
+                    text_list.append([image_text, page_no, 0])
+                    page_no += 1
+                    continue
+
+                order_list = []
+                for x in layout:
+                    if get_platform() == "Windows":
+                        # print("x", page_no, x)
+                        print()
+
+                    if isinstance(x, LTTextBoxHorizontal):
+                        image_text = x.get_text()
+
+                        # 无法识别编码,用ocr
+                        if re.search('[(]cid:[0-9]+[)]', image_text):
+                            print(re.search('[(]cid:[0-9]+[)]', image_text))
+                            with open(output_image_list[page_no], "rb") as ff:
+                                image_stream = ff.read()
+                            image_text = from_ocr_interface(image_stream)
+                            if judge_error_code(image_text):
+                                return image_text
+                            image_text = add_div(image_text)
+                            order_list.append([image_text, page_no, x.bbox[1]])
+                            break
+                        else:
+                            image_text = add_div(image_text)
+                            order_list.append([image_text, page_no, x.bbox[1]])
+                            continue
+
+                    if isinstance(x, LTFigure):
+                        for image in x:
+                            if isinstance(image, LTImage):
+                                try:
+                                    print(image.width, image.height)
+                                    image_stream = image.stream.get_data()
+
+                                    # 有些水印导致pdf分割、读取报错
+                                    # if image.width <= 200 and image.height<=200:
+                                    #     continue
+
+                                    # img_test = Image.open(io.BytesIO(image_stream))
+                                    # img_test.save('temp/LTImage.jpg')
+
+                                    # 查看提取的图片高宽,太大则抛错用另一张图
+                                    img_test = Image.open(io.BytesIO(image_stream))
+                                    if img_test.size[1] > 2000 or img_test.size[0] > 1500:
+                                        print("pdf2text LTImage size", img_test.size)
+                                        raise Exception
+                                    img_test.save('temp/LTImage.jpg')
+
+                                # except pdfminer.pdftypes.PDFNotImplementedError:
+                                #     with open(output_image_list[page_no], "rb") as ff:
+                                #         image_stream = ff.read()
+                                except Exception:
+                                    logging.info("pdf2text pdfminer read image in page fail! use pymupdf read image...")
+                                    print(traceback.print_exc())
+                                    with open(output_image_list[page_no], "rb") as ff:
+                                        image_stream = ff.read()
+                                image_text = from_ocr_interface(image_stream)
+
+                                if judge_error_code(image_text):
+                                    return image_text
+
+                                # 判断只拿到了水印图: 无文字输出且只有图片对象
+                                if image_text == "" and only_image:
+                                    # 拆出该页pdf
+                                    try:
+                                        logging.info("pdf2text guess pdf has watermark")
+                                        split_path = get_single_pdf(path, page_no)
+                                    except:
+                                        # 如果拆分抛异常,则大概率不是水印图,用ocr识别图片
+                                        logging.info("pdf2text guess pdf has no watermark")
+                                        with open(output_image_list[page_no], "rb") as ff:
+                                            image_stream = ff.read()
+                                            image_text = from_ocr_interface(image_stream)
+                                        image_text = image_text
+                                        order_list.append([image_text, page_no, x.bbox[1]])
+                                        continue
+                                    if judge_error_code(split_path):
+                                        return split_path
+
+                                    # 调用office格式转换
+                                    file_path = from_office_interface(split_path, unique_type_dir, 'html', 3)
+                                    # if file_path == [-3]:
+                                    #     return [-3]
+                                    if judge_error_code(file_path):
+                                        return file_path
+
+                                    # 获取html文本
+                                    image_text = get_html_p(file_path)
+                                    if judge_error_code(image_text):
+                                        return image_text
+
+                                if get_platform() == "Windows":
+                                    print("image_text", page_no, image_text)
+                                    with open("temp" + str(x.bbox[0]) + ".jpg", "wb") as ff:
+                                        ff.write(image_stream)
+
+                                image_text = add_div(image_text)
+                                order_list.append([image_text, page_no, x.bbox[1]])
+
+                if get_platform() == "Windows":
+                    print("order_list", page_no, order_list)
+
+                order_list.sort(key=lambda z: z[2], reverse=True)
+                text_list += order_list
+                page_no += 1
+
+        text = ""
+        for t in text_list:
+            # text += add_div(t[0])
+            text += t[0]
+        return [text]
+    except UnicodeDecodeError as e:
+        logging.info("pdf2text pdfminer create pages failed! " + str(e))
+        return [-3]
+    except Exception as e:
+        logging.info("pdf2text error!")
+        print("pdf2text", traceback.print_exc())
+        return [-1]
+
+
@get_memory_info.memory_decorator
@timeout_decorator.timeout(300, timeout_exception=TimeoutError)
def pdf_analyze(interpreter, page, device):
    """Run pdfminer layout analysis for a single page, aborting after 300 s.

    :param interpreter: PDFPageInterpreter bound to the aggregator device
    :param page: the PDFPage to process
    :param device: PDFPageAggregator collecting the layout result
    :return: the layout object produced for this page
    """
    logging.info("into pdf_analyze")
    started = time.time()
    print("pdf_analyze interpreter process...")
    interpreter.process_page(page)
    print("pdf_analyze device get_result...")
    layout = device.get_result()
    logging.info("pdf2text read time " + str(time.time() - started))
    return layout
+
+
def get_html_p(html_path):
    """Extract the concatenated text of all <p> tags from an HTML file.

    :param html_path: path of an HTML file (produced by the office converter)
    :return: joined paragraph text followed by a newline, or [-1] on error
    """
    logging.info("into get_html_p")
    try:
        with open(html_path, "r") as ff:
            html_str = ff.read()

        soup = BeautifulSoup(html_str, 'lxml')
        text = ""
        for p in soup.find_all("p"):
            p_text = p.text
            p_text = p_text.strip()
            # bug fix: the old check was `p.string != ""` — p.string is None
            # for tags with nested markup, and None != "" is always True, so
            # empty paragraphs were appended anyway. Test the stripped text.
            if p_text != "":
                text += p_text
        text += "\n"
        return text
    except Exception as e:
        logging.info("get_html_p error!")
        print("get_html_p", traceback.print_exc())
        return [-1]
+
+
def get_single_pdf(path, page_no):
    """Split one page of a PDF out into its own file.

    :param path: source PDF path
    :param page_no: 0-based page index to extract
    :return: path of the new single-page PDF ("<stem>_split.pdf")
    :raises Exception: re-raised so the caller can decide how to handle pages
        that cannot be split (e.g. suspected watermark pages)
    """
    logging.info("into get_single_pdf")
    try:
        pdf_origin = PdfFileReader(path, strict=False)

        pdf_new = PdfFileWriter()
        pdf_new.addPage(pdf_origin.getPage(page_no))

        # bug fix: path.split(".")[0] truncated at the FIRST dot, which breaks
        # paths like "./dir/file.pdf" or "a.b/c.pdf"; splitext strips only the
        # final extension
        path_new = os.path.splitext(path)[0] + "_split.pdf"
        with open(path_new, "wb") as ff:
            pdf_new.write(ff)
        return path_new
    except PyPDF2.utils.PdfReadError as e:
        raise e
    except Exception as e:
        logging.info("get_single_pdf error! page " + str(page_no))
        print("get_single_pdf", traceback.print_exc())
        raise e
+
+
def page_table_connect(has_table_list, page_info_list):
    """Find tables that continue across consecutive pages and merge their HTML.

    Each page entry has the layout
    [text, column_list, outline_points, is_table, page_no, img_size]
    (as appended by the pdf2text caller above).

    :param has_table_list: the page entries that contain a table
    :param page_info_list: all page entries, indexable by page number
    :return: (table_connect_list, connect_text_list) — groups of connected page
        numbers and [merged_text, page_group] pairs; ([-1], [-1]) on error
    """
    logging.info("into page_table_connect")
    try:
        # groups of page numbers whose tables join across the page break
        table_connect_list = []
        temp_list = []
        # max allowed distance (presumably pixels) between a table outline
        # edge and the image top/bottom
        threshold = 100

        for i in range(1, len(has_table_list)):
            page_info = has_table_list[i]
            last_page_info = has_table_list[i - 1]

            # the two pages must be consecutive
            if page_info[4] - last_page_info[4] == 1:

                # the previous page's last column count and this page's first
                # column count must both be falsy (0/empty) and equal
                if not last_page_info[1][-1] and not page_info[1][0] and \
                        last_page_info[1][-1] == page_info[1][0]:

                    # the previous page's table outline must end near the page
                    # bottom and this page's outline must start near the top
                    # (assumes page_info[5][0] is the image height — TODO confirm)
                    if page_info[5][0] - last_page_info[2][-1][1][1] <= threshold and \
                            page_info[2][0][0][1] - 0 <= 100:
                        # print("page_table_connect accept")
                        temp_list.append(last_page_info[4])
                        temp_list.append(page_info[4])
                        continue

            # chain broken: flush the connected page numbers collected so far
            if len(temp_list) > 1:
                temp_list = list(set(temp_list))
                temp_list.sort(key=lambda x: x)
                table_connect_list.append(temp_list)
                temp_list = []
        # flush the trailing group, if any
        if len(temp_list) > 1:
            temp_list = list(set(temp_list))
            temp_list.sort(key=lambda x: x)
            table_connect_list.append(temp_list)
            temp_list = []

        # merge the text of every connected group into a single table
        connect_text_list = []
        for area in table_connect_list:
            first_page_no = area[0]
            area_page_text = str(page_info_list[first_page_no][0])
            # print("area_page_text", area_page_text)
            for i in range(1, len(area)):
                current_page_no = area[i]
                current_page_text = page_info_list[current_page_no][0]

                # drop the opening <table> tag of the continuation page ...
                table_prefix = re.finditer('<table border="1">', current_page_text)
                index_list = []
                for t in table_prefix:
                    index_list.append(t.span())

                delete_index = index_list[0]
                current_page_text = current_page_text[:delete_index[0]] \
                                    + current_page_text[delete_index[1]:]
                # current_page_text = current_page_text[18:]
                # print("current_page_text", current_page_text[:30])
                # print("current_page_text", current_page_text)

                # ... and the last closing </table> tag of the accumulated
                # text, then append so the rows run on as one table
                table_suffix = re.finditer('</table>', area_page_text)
                index_list = []
                for t in table_suffix:
                    index_list.append(t.span())

                delete_index = index_list[-1]
                area_page_text = area_page_text[:delete_index[0]] \
                                    + area_page_text[delete_index[1]:]
                # area_page_text = area_page_text[:-9]
                # print("area_page_text", area_page_text[-20:])
                area_page_text = area_page_text + current_page_text

            connect_text_list.append([area_page_text, area])

        return table_connect_list, connect_text_list
    except Exception as e:
        # print("page_table_connect", e)
        logging.info("page_table_connect error!")
        print("page_table_connect", traceback.print_exc())
        return [-1], [-1]
+
+
@get_memory_info.memory_decorator
def zip2text(path, unique_type_dir):
    """Unpack a zip archive and convert every contained file to text.

    :param path: path of the zip file
    :param unique_type_dir: directory to extract into
    :return: list of text strings, or an error-code list such as [-1]/[-3]
    """
    logging.info("into zip2text")
    try:
        zip_path = unique_type_dir

        try:
            zip_file = zipfile.ZipFile(path)
            zip_list = zip_file.namelist()

            if get_platform() == "Windows":
                if os.path.exists(zip_list[0]):
                    print("zip2text exists")

            # unpack every member into the target directory
            file_list = [zip_file.extract(name, path=zip_path) for name in zip_list]
            zip_file.close()
        except Exception:
            logging.info("zip format error!")
            print("zip format error!", traceback.print_exc())
            return [-3]

        # rename the extracted files/dirs to hashed names
        file_list = rename_inner_files(zip_path)
        if judge_error_code(file_list):
            return file_list

        if get_platform() == "Windows":
            print("============= zip file list")

        text = []
        for file in file_list:
            if os.path.isdir(file):
                continue

            # no suffix: guess the format from the file content
            if len(file.split(".")) <= 1:
                logging.info(str(file) + " has no type! Guess type...")
                _type = judge_format(file)
                if _type is None:
                    logging.info(str(file) + "cannot guess type!")
                    sub_text = [""]
                else:
                    logging.info(str(file) + " guess type: " + _type)
                    new_file = str(file) + "." + _type
                    os.rename(file, new_file)
                    file = new_file
                    sub_text = getText(_type, file)
            # has a suffix: dispatch on it
            else:
                _type = file.split(".")[-1]
                sub_text = getText(_type, file)

            # skip unreadable members; propagate fatal error codes
            if judge_error_code(sub_text, code=[-3]):
                continue
            if judge_error_code(sub_text):
                return sub_text

            text = text + sub_text
        return text
    except Exception:
        logging.info("zip2text error!")
        print("zip2text", traceback.print_exc())
        return [-1]
+
+
@get_memory_info.memory_decorator
def rar2text(path, unique_type_dir):
    """Extract a .rar archive with the unrar CLI and convert each file to text.

    :param path: path of the rar file
    :param unique_type_dir: directory to extract into
    :return: list of text strings, or an error-code list such as [-1]/[-3]
    """
    logging.info("into rar2text")
    try:
        rar_path = unique_type_dir

        try:
            import subprocess
            # security/robustness fix: the old code built a shell string with
            # os.system, so paths containing spaces or shell metacharacters
            # broke the command (or allowed command injection). An argument
            # list with no shell passes the paths through verbatim.
            _signal = subprocess.call(["unrar", "x", path, rar_path])
            print("rar2text _signal", _signal)
            # unrar returns 0 on success
            if _signal != 0:
                raise Exception
        except Exception as e:
            logging.info("rar format error!")
            print("rar format error!", e)
            return [-3]

        if get_platform() == "Windows":
            print("============= rar file list")

        # rename the extracted files/dirs to hashed names
        file_list = rename_inner_files(rar_path)
        if judge_error_code(file_list):
            return file_list

        text = []
        for file in file_list:
            if os.path.isdir(file):
                continue

            # no suffix: guess the format from the file content
            if len(file.split(".")) <= 1:
                logging.info(str(file) + " has no type! Guess type...")
                _type = judge_format(file)
                if _type is None:
                    logging.info(str(file) + "cannot guess type!")
                    sub_text = [""]
                else:
                    logging.info(str(file) + " guess type: " + _type)
                    new_file = str(file) + "." + _type
                    os.rename(file, new_file)
                    file = new_file
                    sub_text = getText(_type, file)
            # has a suffix: dispatch on it
            else:
                _type = file.split(".")[-1]
                sub_text = getText(_type, file)

            # skip unreadable members; propagate fatal error codes
            if judge_error_code(sub_text, code=[-3]):
                continue
            if judge_error_code(sub_text):
                return sub_text

            text = text + sub_text
        return text
    except Exception as e:
        logging.info("rar2text error!")
        print("rar2text", traceback.print_exc())
        return [-1]
+
+
def inner_file_rename(path_list):
    """Rename extracted files/dirs so path components become their hash values.

    NOTE(review): zip2text/rar2text call rename_inner_files() and reference
    this helper only in commented-out code — this looks superseded; confirm
    before relying on it.

    :param path_list: absolute paths of the extracted files and directories
    :return: list of the renamed paths, or [-1] on error
    """
    logging.info("into inner_file_rename")
    try:
        # first strip '.' characters out of directory names
        path_list.sort(key=lambda x: len(x), reverse=True)
        for i in range(len(path_list)):
            old_path = path_list[i]
            # for directories, check whether the last component contains dots
            # and rename it if so
            if os.path.isdir(old_path):
                ps = old_path.split(os.sep)
                old_p = ps[-2]
                if '.' in old_p:
                    new_p = re.sub("\\.", "", old_p)
                    new_path = ""
                    for p in ps[:-2]:
                        new_path += p + os.sep
                    new_path += new_p + os.sep

                    # rename on disk, then update every stored path containing it
                    # print("has .", path_list[i], new_path)
                    os.rename(old_path, new_path)
                    for j in range(len(path_list)):
                        if old_path in path_list[j]:
                            path_list[j] = re.sub(old_p, new_p, path_list[j]) + os.sep

        # split each path on the separator and record its component count
        path_len_list = []
        for p in path_list:
            p_ss = p.split(os.sep)
            temp_p_ss = []
            for pp in p_ss:
                if pp == "":
                    continue
                temp_p_ss.append(pp)
            p_ss = temp_p_ss
            path_len_list.append([p, p_ss, len(p_ss)])

        # rename paths with the fewest components first, i.e. from the root down
        path_len_list.sort(key=lambda x: x[2])

        # for p in path_len_list:
        #     print("---", p[1])

        # find how many leading components stay unchanged (everything up to
        # and including the "*_rar"/"*_zip" extraction directory)
        no_change_level = 0
        loop = 0
        for p_s in path_len_list[0][1]:
            if p_s[-4:] == "_rar" or p_s[-4:] == "_zip":
                no_change_level += loop
                loop = 0
            loop += 1
        no_change_level += 1

        # rebuild every path with hashed components
        new_path_list = []
        for path_len in path_len_list:
            # the first no_change_level components are the fixed prefix
            new_path = ""
            for i in range(no_change_level):
                new_path += path_len[1][i] + os.sep
            old_path = new_path

            if not get_platform() == "Windows":
                old_path = os.sep + old_path
                new_path = os.sep + new_path
            # print("path_len[1][3:]", path_len[1][3:])

            count = 0
            for p in path_len[1][no_change_level:]:
                # every component of the new path becomes its hash
                new_path += str(hash(p))

                # no os.sep after the last component, and the old path keeps
                # its original (un-hashed) final component
                if count < len(path_len[1][no_change_level:]) - 1:
                    old_path += str(hash(p)) + os.sep
                    new_path += os.sep
                else:
                    old_path += p
                count += 1

            # directories get a trailing os.sep
            if os.path.isdir(path_len[0]):
                new_path += os.sep
                old_path += os.sep
            # files keep their original extension so type dispatch still works
            else:
                p_ss = path_len[1][-1].split(".")
                if len(p_ss) > 1:
                    path_suffix = "." + p_ss[-1]
                    new_path += path_suffix

            print("inner_file_rename", old_path, "to", new_path)
            os.rename(old_path, new_path)
            new_path_list.append(new_path)

        return new_path_list
    except Exception as e:
        logging.info("inner_file_rename error!")
        print("inner_file_rename", traceback.print_exc())
        return [-1]
+
+
def rename_inner_files(root_path):
    """Rename every file/dir under root_path to the hash of its last component.

    Deepest paths are renamed first (longest relative path first) so parent
    renames never invalidate a child path that is still pending. File
    extensions are preserved so the converters can still dispatch on suffix.

    :param root_path: extraction directory prefix (should end with os.sep)
    :return: list of the new absolute paths (dirs end with os.sep), or [-1]
    """
    try:
        logging.info("into rename_inner_files")
        # collect every file and dir path relative to root_path
        path_list = []
        for root, dirs, files in os.walk(root_path, topdown=False):
            for name in dirs:
                p = os.path.join(root, name) + os.sep
                # bug fix: the old code used re.sub(root_path, "", p), which
                # treats the path as a regular expression — roots containing
                # regex metacharacters (e.g. "\" on Windows, "+", "(") broke
                # or mangled the result, and it replaced ALL occurrences.
                # Plain first-occurrence prefix removal is what is intended.
                path_list.append(p.replace(root_path, "", 1))
            for name in files:
                p = os.path.join(root, name)
                path_list.append(p.replace(root_path, "", 1))

        # longest (deepest) relative paths first
        path_list.sort(key=lambda x: len(x), reverse=True)

        for old_path in path_list:
            ss = old_path.split(os.sep)
            file_type = ""
            if os.path.isdir(root_path + old_path):
                # drop the empty trailing segment produced by the ending os.sep
                ss = ss[:-1]
            elif "." in old_path:
                # keep the extension so the file type stays recognizable
                file_type = "." + old_path.split(".")[-1]

            # hash only the last component; parents keep their current names
            new_path = ""
            current_level = 0
            for s in ss:
                if current_level < len(ss) - 1:
                    new_path += s + os.sep
                else:
                    new_path += str(hash(s)) + file_type
                current_level += 1

            os.rename(root_path + old_path, root_path + new_path)

        # walk again to gather the final absolute paths
        new_path_list = []
        for root, dirs, files in os.walk(root_path, topdown=False):
            for name in dirs:
                new_path_list.append(os.path.join(root, name) + os.sep)
            for name in files:
                new_path_list.append(os.path.join(root, name))
        # print("new_path_list", new_path_list)
        return new_path_list
    except:
        traceback.print_exc()
        return [-1]
+
+
@get_memory_info.memory_decorator
def xls2text(path, unique_type_dir):
    """Convert a .xls file to text by first converting it to .xlsx.

    :param path: path of the xls file
    :param unique_type_dir: working directory for the converted file
    :return: list of text strings, or an error-code list
    """
    logging.info("into xls2text")
    try:
        # convert to xlsx via the LibreOffice conversion service
        converted_path = from_office_interface(path, unique_type_dir, 'xlsx')
        if judge_error_code(converted_path):
            return converted_path

        # then reuse the xlsx reader
        result = xlsx2text(converted_path, unique_type_dir)
        if judge_error_code(result):
            return result

        return result
    except Exception:
        logging.info("xls2text error!")
        print("xls2text", traceback.print_exc())
        return [-1]
+
+
@get_memory_info.memory_decorator
def xlsx2text(path, unique_type_dir):
    """Read every sheet of an xlsx workbook and render each as an HTML table.

    :param path: path of the xlsx file
    :param unique_type_dir: working directory (unused here)
    :return: [html_text] on success, [-3] on format error, [-1] otherwise
    """
    logging.info("into xlsx2text")
    try:
        try:
            # sheet_name=None -> read all sheets into a dict of DataFrames
            df_dict = pandas.read_excel(path, header=None, keep_default_na=False, sheet_name=None)
        except Exception:
            logging.info("xlsx format error!")
            return [-3]

        sheet_text = ""
        for df in df_dict.values():
            rows_html = ""
            for _, row in df.iterrows():
                cells = "".join("<td>" + str(cell) + "</td>" + "\n" for cell in row)
                rows_html += "<tr>" + cells + "</tr>" + "\n"
            sheet_text += '<table border="1">' + "\n" + rows_html + "</table>" + "\n"

        return [sheet_text]
    except Exception:
        logging.info("xlsx2text error!")
        print("xlsx2text", traceback.print_exc())
        return [-1]
+
+
@get_memory_info.memory_decorator
def swf2text(path, unique_type_dir):
    """Convert a .swf file to text: export it to SVG, extract the embedded
    base64-encoded images, save each one as a PNG and OCR it.

    :param path: path of the swf file
    :param unique_type_dir: working directory (unused here)
    :return: [text] on success, [-3] on format error, [-1] on other errors
    """
    logging.info("into swf2text")
    try:
        try:
            with open(path, 'rb') as f:
                swf_file = SWF(f)
                svg_exporter = SVGExporter()
                svg = swf_file.export(svg_exporter)
            # with open('swf_export.jpg', 'wb') as f:
            #     f.write(svg.read())
            swf_str = str(svg.getvalue(), encoding='utf-8')
        except Exception as e:
            logging.info("swf format error!")
            traceback.print_exc()
            return [-3]

        # locate every embedded <image ...> element in the exported SVG
        result0 = re.finditer('<image id=(.[^>]*)', swf_str)
        image_bytes_list = []
        i = 0
        image_path_prefix = path.split(".")[-2] + "_" + path.split(".")[-1]
        image_path_list = []
        for r in result0:
            # slice out this image element's markup
            swf_str0 = swf_str[r.span()[0]:r.span()[1] + 1]

            # extract the base64 payload from the xlink:href data URI
            result1 = re.search('xlink:href="data:(.[^>]*)', swf_str0)
            swf_str1 = swf_str0[result1.span()[0]:result1.span()[1]]
            reg1_prefix = 'b\''
            result1 = re.search(reg1_prefix + '(.[^\']*)', swf_str1)
            swf_str1 = swf_str1[result1.span()[0] + len(reg1_prefix):result1.span()[1]]

            # base64_str -> base64_bytes -> unescaped base64_bytes -> raw image bytes
            base64_bytes_with_double = bytes(swf_str1, "utf-8")
            base64_bytes = codecs.escape_decode(base64_bytes_with_double, "hex-escape")[0]
            image_bytes = base64.b64decode(base64_bytes)
            image_bytes_list.append(image_bytes)
            image_path = image_path_prefix + "_page_" + str(i) + ".png"
            with open(image_path, 'wb') as f:
                f.write(image_bytes)

            image_path_list.append(image_path)
            # (disabled) extract the image's width/height from the markup
            # reg2_prefix = 'width="'
            # result2 = re.search(reg2_prefix + '(\d+)', swf_str0)
            # swf_str2 = swf_str0[result2.span()[0]+len(reg2_prefix):result2.span()[1]]
            # width = swf_str2
            # reg2_prefix = 'height="'
            # result2 = re.search(reg2_prefix + '(\d+)', swf_str0)
            # swf_str2 = swf_str0[result2.span()[0]+len(reg2_prefix):result2.span()[1]]
            # height = swf_str2
            i += 1

        text_list = []
        # print("image_path_list", image_path_list)
        for image_path in image_path_list:
            text = picture2text(image_path)
            # print("text", text)

            # skip images that cannot be read; propagate fatal error codes
            if judge_error_code(text, code=[-3]):
                continue
            if judge_error_code(text):
                return text

            text = text[0]
            text_list.append(text)

        text = ""
        for t in text_list:
            text += t

        return [text]
    except Exception as e:
        logging.info("swf2text error!")
        print("swf2text", traceback.print_exc())
        return [-1]
+
+
@get_memory_info.memory_decorator
def picture2text(path, html=False):
    """Extract text (and tables) from a single image file.

    :param path: filesystem path of the image to read
    :param html: when True, wrap the extracted text in <div> markup
    :return: [text] on success, or an error-code list:
        [-3] unreadable image, [-1] unexpected failure, plus any code
        propagated from image_preprocess via judge_error_code.
    """
    logging.info("into picture2text")
    try:
        # Detect tables in the image; cv2.imread returns None for
        # missing/corrupt files -> report "file format error" (-3).
        img = cv2.imread(path)
        if img is None:
            return [-3]

        # if get_platform() == "Windows":
        #     print("picture2text img", img)

        # image_preprocess runs OCR/OTR and rebuilds text + table layout.
        text, column_list, outline_points, is_table = image_preprocess(img, path)
        if judge_error_code(text):
            return text
        # if text == [-5]:
        #     return [-5]
        # if text == [-2]:
        #     return [-2]
        # if text == [-1]:
        #     return [-1]

        if html:
            text = add_div(text)
        return [text]
    except Exception as e:
        logging.info("picture2text error!")
        print("picture2text", traceback.print_exc())
        return [-1]
+
+
# Round-robin call counter; a one-element list so the function can mutate it.
port_num = [0]


def choose_port():
    """Return the next local service URL, cycling through four ports."""
    suffixes = [":15011", ":15012", ":15013", ":15014"]
    _url = local_url + suffixes[port_num[0] % len(suffixes)]
    port_num[0] = port_num[0] + 1
    return _url
+
+
@get_memory_info.memory_decorator
def from_ocr_interface(image_stream, is_table=False):
    """Send image bytes to the OCR service and return the recognized text.

    :param image_stream: raw image bytes; base64-encoded before the call
    :param is_table: when True, return (text_list, bbox_list) so the caller
        can rebuild a table; otherwise return one merged text string
    :return: merged text string (non-table), or the pair of lists (table).
        Error codes: [-5] timeout, [-2] connection error, [-1] other failure
        (doubled into a pair when is_table is True).
    """
    logging.info("into from_ocr_interface")
    try:
        base64_stream = base64.b64encode(image_stream)

        # Call the OCR interface.
        try:
            r = ocr(data=base64_stream, ocr_model=globals().get("global_ocr_model"))
        except TimeoutError:
            if is_table:
                return [-5], [-5]
            else:
                return [-5]
        except requests.exceptions.ConnectionError as e:
            if is_table:
                return [-2], [-2]
            else:
                return [-2]

        _dict = r
        # NOTE(review): eval() on interface output assumes the OCR service is
        # trusted; ast.literal_eval would be safer — confirm response format.
        text_list = eval(_dict.get("text"))
        bbox_list = eval(_dict.get("bbox"))
        if text_list is None:
            text_list = []
        if bbox_list is None:
            bbox_list = []

        if is_table:
            return text_list, bbox_list
        else:
            if text_list and bbox_list:
                # Order the recognized fragments by their bbox coordinates.
                text = get_sequential_data(text_list, bbox_list, html=True)
                if judge_error_code(text):
                    return text
                # if text == [-1]:
                #     return [-1]
            else:
                text = ""
            return text
    except Exception as e:
        logging.info("from_ocr_interface error!")
        # print("from_ocr_interface", e, global_type)
        if is_table:
            return [-1], [-1]
        else:
            return [-1]
+
+
@get_memory_info.memory_decorator
def from_otr_interface(image_stream):
    """Send image bytes to the OTR (table-line detection) service.

    :param image_stream: raw image bytes; base64-encoded before the call
    :return: (points, split_lines, bboxes, outline_points) lists on success.
        Error codes, repeated four times: [-5] timeout, [-2] connection
        error, [-1] other failure.
    """
    logging.info("into from_otr_interface")
    try:
        base64_stream = base64.b64encode(image_stream)

        # Call the OTR interface.
        try:
            r = otr(data=base64_stream, otr_model=globals().get("global_otr_model"))
        except TimeoutError:
            return [-5], [-5], [-5], [-5]
        except requests.exceptions.ConnectionError as e:
            logging.info("from_otr_interface")
            print("from_otr_interface", traceback.print_exc())
            return [-2], [-2], [-2], [-2]

        # Unpack the response; each field is a repr'd Python list.
        # NOTE(review): eval() assumes a trusted service — confirm.
        _dict = r
        points = eval(_dict.get("points"))
        split_lines = eval(_dict.get("split_lines"))
        bboxes = eval(_dict.get("bboxes"))
        outline_points = eval(_dict.get("outline_points"))
        # print("from_otr_interface len(bboxes)", len(bboxes))
        if points is None:
            points = []
        if split_lines is None:
            split_lines = []
        if bboxes is None:
            bboxes = []
        if outline_points is None:
            outline_points = []
        return points, split_lines, bboxes, outline_points
    except Exception as e:
        logging.info("from_otr_interface error!")
        print("from_otr_interface", traceback.print_exc())
        return [-1], [-1], [-1], [-1]
+
+
def from_office_interface(src_path, dest_path, target_format, retry_times=1):
    """Convert an office document to *target_format* via office_convert.

    :param src_path: path of the source document
    :param dest_path: directory/path for the converted output
    :param target_format: e.g. "docx", "xlsx", "pdf"
    :param retry_times: conversion retries passed through to office_convert
    :return: the converted file path (or the error-code list produced by
        office_convert itself); [-5] on timeout, [-1] on any other failure.
    """
    try:
        # Both the Windows and Linux branches ended up calling
        # office_convert directly (the timeout-decorator variants are kept
        # below as history), so the platform split is no longer needed.
        # origin_office_convert = office_convert.__wrapped__
        # timeout_decorator_obj = my_timeout_decorator.TimeoutClass(office_convert, 180, TimeoutError)
        # file_path = timeout_decorator_obj.run(src_path, dest_path, target_format, retry_times)
        file_path = office_convert(src_path, dest_path, target_format, retry_times)

        # office_convert reports its own error codes in-band; either way the
        # value is handed straight back to the caller.
        return file_path
    except TimeoutError:
        logging.info("from_office_interface timeout error!")
        return [-5]
    except:
        logging.info("from_office_interface error!")
        print("from_office_interface", traceback.print_exc())
        return [-1]
+
+
def get_sequential_data(text_list, bbox_list, html=False):
    """Re-assemble OCR fragments into reading order using their bboxes.

    :param text_list: recognized text fragments
    :param bbox_list: per-fragment corner points; [0] is used as top-left,
        [1] as top-right, [-1] as bottom-left  # assumes 4-point boxes — TODO confirm
    :param html: when True, wrap every output line in <div> tags
    :return: the ordered text as one string ("" when there is no input),
        or [-1] on failure
    """
    logging.info("into get_sequential_data")
    try:
        text = ""
        order_list = []
        for i in range(len(text_list)):
            # Flatten each fragment to [text, x_start, x_end, y_start, y_end].
            length_start = bbox_list[i][0][0]
            length_end = bbox_list[i][1][0]
            height_start = bbox_list[i][0][1]
            height_end = bbox_list[i][-1][1]
            # print([length_start, length_end, height_start, height_end])
            order_list.append([text_list[i], length_start, length_end, height_start, height_end])
            # text = text + infomation['text'] + "\n"

        if get_platform() == "Windows":
            print("get_sequential_data", order_list)
        if not order_list:
            if get_platform() == "Windows":
                print("get_sequential_data", "no order list")
            return ""

        # Sort fragments top-to-bottom, then left-to-right.
        order_list.sort(key=lambda x: (x[3], x[1]))

        # (older row/column grouping kept for reference)
        # col_list = []
        # height_end = int((order_list[0][4] + order_list[0][3]) / 2)
        # for i in range(len(order_list)):
        #     if height_end - threshold <= order_list[i][3] <= height_end + threshold:
        #         col_list.append(order_list[i])
        #     else:
        #         row_list.append(col_list)
        #         col_list = []
        #         height_end = int((order_list[i][4] + order_list[i][3]) / 2)
        #         col_list.append(order_list[i])
        #     if i == len(order_list) - 1:
        #         row_list.append(col_list)

        # Group fragments into rows: two fragments share a row when their
        # vertical centers are within `threshold` pixels of each other.
        row_list = []
        used_box = []
        threshold = 5
        for box in order_list:
            if box in used_box:
                continue

            height_center = (box[4] + box[3]) / 2
            row = []
            for box2 in order_list:
                if box2 in used_box:
                    continue
                height_center2 = (box2[4] + box2[3]) / 2
                if height_center - threshold <= height_center2 <= height_center + threshold:
                    if box2 not in row:
                        row.append(box2)
                        used_box.append(box2)
            row.sort(key=lambda x: x[0])
            row_list.append(row)

        # Emit one output line per row; multi-fragment rows are joined with
        # spaces after a left-to-right sort on x_start.
        for row in row_list:
            if not row:
                continue
            if len(row) <= 1:
                text = text + row[0][0] + "\n"
            else:
                sub_text = ""
                row.sort(key=lambda x: x[1])
                for col in row:
                    sub_text = sub_text + col[0] + " "
                sub_text = sub_text + "\n"
                text += sub_text

        if html:
            text = "<div>" + text
            text = re.sub("\n", "</div>\n<div>", text)
            text += "</div>"
            # if text[-5:] == "<div>":
            #     text = text[:-5]
        return text

    except Exception as e:
        logging.info("get_sequential_data error!")
        print("get_sequential_data", traceback.print_exc())
        return [-1]
+
+
def get_formatted_table(text_list, text_bbox_list, table_bbox_list, split_line):
    """Match OCR text boxes to detected table cells and emit HTML tables.

    :param text_list: recognized text fragments
    :param text_bbox_list: per-fragment 4-point boxes; points [0] and [2]
        are taken as the opposing corners  # TODO confirm point order
    :param table_bbox_list: detected table-cell rectangles [top-left, bottom-right]
    :param split_line: horizontal separator lines dividing the page into areas
    :return: (html_text, column_count_list) — one column count per area
        (0 for areas without a table); ([-1], [-1]) on failure
    """
    logging.info("into get_formatted_table")
    try:
        # Re-shape text_bbox_list into [top_left, bottom_right, text].
        text_bbox_list = [[text_bbox_list[i][0], text_bbox_list[i][2], text_list[i]] for i in
                          range(len(text_bbox_list))]
        # Sort both box lists by vertical, then horizontal position.
        text_bbox_list.sort(key=lambda x: (x[0][1], x[0][0]))
        table_bbox_list.sort(key=lambda x: (x[0][1], x[0][0]))

        # print("text_bbox_list", text_bbox_list)
        # print("table_bbox_list", table_bbox_list)

        # Pixel tolerance used for all bbox comparisons below.
        threshold = 5

        # Partition the page by split_line; one area may contain several
        # tables. area_*_bbox_list hold the boxes falling into each area.
        area_text_bbox_list = []
        area_table_bbox_list = []
        # print("get_formatted_table, split_line", split_line)
        for j in range(1, len(split_line)):
            last_y = split_line[j - 1][0][1]
            current_y = split_line[j][0][1]
            temp_text_bbox_list = []
            temp_table_bbox_list = []

            # Collect the text boxes whose centers fall inside this area.
            for text_bbox in text_bbox_list:
                # Center of the text bbox.
                text_bbox_center = ((text_bbox[1][0] + text_bbox[0][0]) / 2,
                                    (text_bbox[1][1] + text_bbox[0][1]) / 2)
                if last_y - threshold <= text_bbox_center[1] <= current_y + threshold:
                    temp_text_bbox_list.append(text_bbox)
            area_text_bbox_list.append(temp_text_bbox_list)

            # Collect the table-cell boxes whose centers fall inside this area.
            for table_bbox in table_bbox_list:
                # Center of the table bbox.
                table_bbox_center = ((table_bbox[1][0] + table_bbox[0][0]) / 2,
                                     (table_bbox[1][1] + table_bbox[0][1]) / 2)
                if last_y < table_bbox_center[1] < current_y:
                    temp_table_bbox_list.append(table_bbox)
            area_table_bbox_list.append(temp_table_bbox_list)

        # Per area: match text boxes to table cells and build the output.
        area_text_list = []
        area_column_list = []
        for j in range(len(area_text_bbox_list)):
            # Table cells and text boxes belonging to this area.
            temp_table_bbox_list = area_table_bbox_list[j]
            temp_text_bbox_list = area_text_bbox_list[j]

            # No table in this area: just concatenate its text in reading
            # order and record a column count of 0.
            if not temp_table_bbox_list:
                # Gather all text boxes of the area.
                only_text_list = []
                only_bbox_list = []
                for text_bbox in temp_text_bbox_list:
                    only_text_list.append(text_bbox[2])
                    only_bbox_list.append([text_bbox[0], text_bbox[1]])
                only_text = get_sequential_data(only_text_list, only_bbox_list, True)
                if only_text == [-1]:
                    return [-1], [-1]
                area_text_list.append(only_text)
                area_column_list.append(0)
                continue

            # Area contains a table.
            # Map each table cell (keyed by its repr) to the text inside it.
            text_in_table = {}
            for i in range(len(temp_text_bbox_list)):
                text_bbox = temp_text_bbox_list[i]

                # Center of the text bbox.
                text_bbox_center = ((text_bbox[1][0] + text_bbox[0][0]) / 2,
                                    (text_bbox[1][1] + text_bbox[0][1]) / 2)

                # Find the table cell containing this center point.
                for table_bbox in temp_table_bbox_list:
                    # Center inside the cell: append the text to that cell.
                    if table_bbox[0][0] <= text_bbox_center[0] <= table_bbox[1][0] and \
                            table_bbox[0][1] <= text_bbox_center[1] <= table_bbox[1][1]:
                        if str(table_bbox) in text_in_table.keys():
                            text_in_table[str(table_bbox)] = text_in_table.get(str(table_bbox)) + text_bbox[2]
                        else:
                            text_in_table[str(table_bbox)] = text_bbox[2]
                        break

                    # (disabled) retry the match with a wider threshold when
                    # no cell contained the exact center point
                    # elif (table_bbox[0][0] <= text_bbox_center[0]+threshold <= table_bbox[1][0] and
                    #         table_bbox[0][1] <= text_bbox_center[1]+threshold <= table_bbox[1][1]) or \
                    #         (table_bbox[0][0] <= text_bbox_center[0]-threshold <= table_bbox[1][0] and
                    #          table_bbox[0][1] <= text_bbox_center[1]-threshold <= table_bbox[1][1]) or \
                    #         (table_bbox[0][0] <= text_bbox_center[0]+threshold <= table_bbox[1][0] and
                    #          table_bbox[0][1] <= text_bbox_center[1]-threshold <= table_bbox[1][1]) or \
                    #         (table_bbox[0][0] <= text_bbox_center[0]-threshold <= table_bbox[1][0] and
                    #          table_bbox[0][1] <= text_bbox_center[1]+threshold <= table_bbox[1][1]):
                    #     if str(table_bbox) in text_in_table.keys():
                    #         text_in_table[str(table_bbox)] = text_in_table.get(str(table_bbox)) + text_bbox[2]
                    #     else:
                    #         text_in_table[str(table_bbox)] = text_bbox[2]
                    #     break

            # Collect every distinct x and y grid coordinate of the cells;
            # these define the table's row/column grid lines.
            all_col_list = []
            all_row_list = []
            for i in range(len(temp_table_bbox_list)):
                table_bbox = temp_table_bbox_list[i]

                # Record both x coordinates of the cell.
                if table_bbox[0][0] not in all_col_list:
                    all_col_list.append(table_bbox[0][0])
                if table_bbox[1][0] not in all_col_list:
                    all_col_list.append(table_bbox[1][0])

                # Record both y coordinates of the cell.
                if table_bbox[0][1] not in all_row_list:
                    all_row_list.append(table_bbox[0][1])
                if table_bbox[1][1] not in all_row_list:
                    all_row_list.append(table_bbox[1][1])
            all_col_list.sort(key=lambda x: x)
            all_row_list.sort(key=lambda x: x)

            # Group the cells into rows: cells whose top y lies within
            # `threshold` of the current row's y belong to that row.
            row_list = []
            rows = []
            temp_table_bbox_list.sort(key=lambda x: (x[0][1], x[0][0], x[1][1], x[1][0]))
            y_row = temp_table_bbox_list[0][0][1]
            for i in range(len(temp_table_bbox_list)):
                table_bbox = temp_table_bbox_list[i]

                if y_row - threshold <= table_bbox[0][1] <= y_row + threshold:
                    rows.append(table_bbox)
                else:
                    y_row = table_bbox[0][1]
                    if rows:
                        rows.sort(key=lambda x: x[0][0])
                        row_list.append(rows)
                    rows = []
                    rows.append(table_bbox)
                # print("*" * 30)
                # print(row_list)

                if i == len(temp_table_bbox_list) - 1:
                    if rows:
                        rows.sort(key=lambda x: x[0][0])
                        row_list.append(rows)

            # Emit the HTML table, computing row/col spans from how many
            # grid lines each cell straddles.
            area_column = []
            text = '<table border="1">' + "\n"
            for row in row_list:
                text += "<tr>" + "\n"
                for col in row:
                    # Number of grid y-lines strictly inside the cell + 1
                    # gives the rowspan (2px margin avoids edge jitter).
                    row_span = 1
                    for y in all_row_list:
                        if col[0][1] < y < col[1][1]:
                            if y - col[0][1] >= 2 and col[1][1] - y >= 2:
                                row_span += 1

                    # Same for grid x-lines -> colspan.
                    col_span = 1
                    for x in all_col_list:
                        if col[0][0] < x < col[1][0]:
                            if x - col[0][0] >= 2 and col[1][0] - x >= 2:
                                col_span += 1

                    text += "<td colspan=" + str(col_span) + " rowspan=" + str(row_span) + ">"

                    if str(col) in text_in_table.keys():
                        text += text_in_table.get(str(col))
                    else:
                        text += ""
                    text += "</td>" + "\n"
                text += "</tr>" + "\n"
            text += "</table>" + "\n"

            # Widest row = the area's column count.
            max_col_num = 0
            for row in row_list:
                col_num = 0
                for col in row:
                    col_num += 1
                if max_col_num < col_num:
                    max_col_num = col_num

            area_text_list.append(text)
            area_column_list.append(max_col_num)

        text = ""
        if get_platform() == "Windows":
            print("get_formatted_table area_text_list", area_text_list)
        for area_text in area_text_list:
            text += area_text
        return text, area_column_list
    except Exception as e:
        logging.info("get_formatted_table error!")
        print("get_formatted_table", traceback.print_exc())
        return [-1], [-1]
+
+
def getText(_type, path_or_stream):
    """Dispatch a file to the converter matching its extension type.

    :param _type: lowercase extension string, e.g. "pdf", "docx", "jpg"
    :param path_or_stream: path of the file to convert
    :return: the chosen converter's result list; [""] for unknown types
    """
    print("file type - " + _type)
    logging.info("file type - " + _type)

    # Each converter gets a working directory unique to this file + type.
    try:
        ss = path_or_stream.split(".")
        unique_type_dir = ss[-2] + "_" + ss[-1] + os.sep
    except:
        unique_type_dir = path_or_stream + "_" + _type + os.sep

    # Converters taking (path, unique_dir).
    dir_converters = {
        "pdf": pdf2text,
        "docx": docx2text,
        "zip": zip2text,
        "rar": rar2text,
        "xlsx": xlsx2text,
        "xls": xls2text,
        "doc": doc2text,
        "swf": swf2text,
    }
    if _type in dir_converters:
        return dir_converters[_type](path_or_stream, unique_type_dir)
    # Converters taking only the path.
    if _type in ("jpg", "png", "jpeg"):
        return picture2text(path_or_stream)
    if _type == "txt":
        return txt2text(path_or_stream)

    return [""]
+
+
def to_html(path, text):
    """Write *text* into a minimal HTML page at *path* (overwrites)."""
    pieces = [
        "<!DOCTYPE HTML>",
        '<head><meta charset="UTF-8"></head>',
        "<body>",
        text,
        "</body>",
    ]
    with open(path, 'w') as f:
        for piece in pieces:
            f.write(piece)
+
+
def resize_image(image_path, size):
    """Shrink the image at *image_path* toward *size*, keeping aspect ratio.

    :param image_path: file overwritten in place with the resized image
    :param size: (height_standard, width_standard) target bounds
    :return: None (errors are logged and swallowed)
    """
    try:
        image_np = cv2.imread(image_path)
        # print(image_np.shape)
        width = image_np.shape[1]
        height = image_np.shape[0]
        h_w_rate = height / width

        # width_standard = 900
        # height_standard = 1400

        width_standard = size[1]
        height_standard = size[0]

        # Candidate dimensions that preserve the aspect ratio.
        width_new = int(height_standard / h_w_rate)
        height_new = int(width_standard * h_w_rate)

        # NOTE(review): when both dimensions exceed the standard only the
        # width branch runs, so the result can still be taller than
        # height_standard — confirm this is intended.
        if width > width_standard:
            image_np = cv2.resize(image_np, (width_standard, height_new))
        elif height > height_standard:
            image_np = cv2.resize(image_np, (width_new, height_standard))

        cv2.imwrite(image_path, image_np)
        # print("resize_image", image_np.shape)
        return
    except Exception as e:
        logging.info("resize_image")
        print("resize_image", e, global_type)
        return
+
+
def remove_red_seal(image_np):
    """Remove red stamp/seal marks from a BGR image.

    NOTE(review): the cv2.imshow/waitKey(0) calls below block until a key is
    pressed — debug leftovers that must not run headless; confirm before use.

    :param image_np: BGR image array (cv2.imread order)
    :return: image with the red seal suppressed
    """
    # Split out the red channel (cv2 loads BGR).
    blue_c, green_c, red_c = cv2.split(image_np)

    # With cv2.THRESH_OTSU and thresh=0, OpenCV picks the optimal threshold.
    thresh, ret = cv2.threshold(red_c, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # print("remove_red_seal thresh", thresh)

    # Empirically ~95-98% of the Otsu threshold works better here.
    filter_condition = int(thresh * 0.98)
    thresh1, red_thresh = cv2.threshold(red_c, filter_condition, 255, cv2.THRESH_BINARY)

    # Expand the mask back to 3 channels.
    image_and = np.expand_dims(red_thresh, axis=2)
    image_and = np.concatenate((image_and, image_and, image_and), axis=-1)
    # print(image_and.shape)

    # Erode (thicken dark strokes) so thin seal remnants are covered.
    gray = cv2.cvtColor(image_and, cv2.COLOR_RGB2GRAY)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    erode = cv2.erode(gray, kernel)
    cv2.imshow("erode", erode)
    cv2.waitKey(0)

    # Combine the de-sealed red mask with the blue channel.
    image_and = np.bitwise_and(cv2.bitwise_not(blue_c), cv2.bitwise_not(erode))
    result_img = cv2.bitwise_not(image_and)

    cv2.imshow("remove_red_seal", result_img)
    cv2.waitKey(0)
    return result_img
+
+
def remove_underline(image_np):
    """Detect (and visualize) text underlines in an image.

    NOTE(review): this function only shows intermediate images via
    cv2.imshow/waitKey(0) and returns None — it looks like unfinished debug
    code rather than an actual underline remover; confirm before calling.

    :param image_np: BGR image array
    :return: None
    """
    # Grayscale.
    gray = cv2.cvtColor(image_np, cv2.COLOR_BGR2GRAY)
    # Adaptive binarization on the inverted image.
    binary = cv2.adaptiveThreshold(~gray, 255,
                                   cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,
                                   15, 10)

    # Sobel kernels: horizontal- and vertical-edge responses.
    kernel_row = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], np.float32)
    kernel_col = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], np.float32)

    # binary = cv2.filter2D(binary, -1, kernel=kernel)
    binary_row = cv2.filter2D(binary, -1, kernel=kernel_row)
    binary_col = cv2.filter2D(binary, -1, kernel=kernel_col)
    cv2.imshow("custom_blur_demo", binary)
    cv2.waitKey(0)

    rows, cols = binary.shape
    # Extract long horizontal lines via morphological erosion/dilation.
    scale = 5
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (cols // scale, 1))
    erodedcol = cv2.erode(binary_row, kernel, iterations=1)
    cv2.imshow("Eroded Image", erodedcol)
    cv2.waitKey(0)
    dilatedcol = cv2.dilate(erodedcol, kernel, iterations=1)
    cv2.imshow("dilate Image", dilatedcol)
    cv2.waitKey(0)
    return
+
+
def getMDFFromFile(path):
    """Compute the MD5 hex digest and byte length of the file at *path*.

    :return: (hexdigest, length) on success; (None, bytes_read_so_far)
        when the file cannot be read.
    """
    _length = 0
    try:
        _md5 = hashlib.md5()
        with open(path, "rb") as ff:
            # Stream the file in 4 KiB chunks to keep memory flat.
            for chunk in iter(lambda: ff.read(4096), b""):
                _length += len(chunk)
                _md5.update(chunk)
        return _md5.hexdigest(), _length
    except Exception as e:
        traceback.print_exc()
        return None, _length
+
+
def add_html_format(text_list):
    """Wrap each fragment of *text_list* in a minimal HTML page skeleton."""
    prefix = '<!DOCTYPE HTML>\n<head><meta charset="UTF-8"></head>\n<body>\n'
    suffix = "\n</body>\n"
    return [prefix + fragment + suffix for fragment in text_list]
+
+
@timeout_decorator.timeout(1200, timeout_exception=TimeoutError)
def unique_temp_file_process(stream, _type):
    """Persist *stream* into a unique temp directory, convert it, clean up.

    :param stream: raw file bytes received from the caller
    :param _type: file extension string selecting the converter
    :return: getText()'s result list, or [-1] on failure

    Fixes over the previous version:
    - file_path/unique_space_path are pre-initialized so the finally block
      cannot raise NameError when an early step fails;
    - on a (very unlikely) uuid collision the NEW directory is actually
      used instead of writing into the colliding one;
    - the cleanup path no longer `return [-1]`s from finally, which used to
      silently replace a successful conversion result.
    """
    logging.info("into unique_temp_file_process")
    file_path = None
    unique_space_path = None
    try:
        # Each call gets a unique workspace under <_path>/temp/.
        temp_root = _path + os.sep + "temp" + os.sep
        if not os.path.exists(temp_root):
            os.mkdir(temp_root)
        unique_space_path = temp_root + uuid.uuid1().hex + os.sep
        # Guard against an (improbable) uuid1 collision.
        if os.path.exists(unique_space_path):
            unique_space_path = temp_root + uuid.uuid1().hex + os.sep
        os.mkdir(unique_space_path)

        # The uploaded file itself also gets a unique name.
        file_path = unique_space_path + uuid.uuid1().hex + "." + _type
        with open(file_path, "wb") as ff:
            ff.write(stream)

        # Skip known-problematic files by md5 (blacklist of past failures).
        print("getMDFFromFile", getMDFFromFile(file_path))
        if getMDFFromFile(file_path)[0] in ('84dba5a65339f338d3ebdf9f33fae13e',
                                            '3d9f9f4354582d85b21b060ebd5786db',
                                            'b52da40f24c6b29dfc2ebeaefe4e41f1',
                                            'eefb925b7ccec1467be20b462fde2a09'):
            raise Exception

        text = getText(_type, file_path)
        return text
    except Exception as e:
        logging.info("unique_temp_file_process")
        print("unique_temp_file_process:", traceback.print_exc())
        return [-1]
    finally:
        print("======================================")
        # file_path stays None if we failed before the stream was written.
        if file_path is not None:
            print("File md5:", getMDFFromFile(file_path))
        try:
            if get_platform() == "Linux":
                # Remove the whole per-call workspace.
                if unique_space_path and os.path.exists(unique_space_path):
                    shutil.rmtree(unique_space_path)
                print()
        except Exception as e:
            logging.info("Delete Files Failed!")
            # Deliberately do NOT return here: a cleanup failure must not
            # overwrite the conversion result.
        print("Finally")

    # to_html(_path + "6.html", text[0])
    # to_html(unique_space_path + "result.html", text[0])
    # return text
+
+
# Module-level logging configuration and logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def log(msg):
    """Emit *msg* at INFO level through the module logger."""
    logger.info(msg)
+
+
def cut_str(text_list, only_text_list, max_bytes_length=2000000):
    """Cap the extracted text so its UTF-8 size stays below *max_bytes_length*.

    Strategy: keep the formatted (HTML) text if it fits; otherwise fall back
    to the plain text; otherwise return a single truncated plain-text string.

    :param text_list: formatted (HTML) text fragments
    :param only_text_list: the same content with markup stripped
    :return: whichever list fits, or [truncated_text]; ["-1"] on failure
    """
    logging.info("into cut_str")
    try:
        def _utf8_size(strings):
            # Total UTF-8 byte length across all fragments.
            return sum(len(bytes(s, encoding='utf-8')) for s in strings)

        formatted_bytes = _utf8_size(text_list)
        print("text_list", formatted_bytes)

        # Formatted text already fits: return it untouched.
        if formatted_bytes < max_bytes_length:
            print("return text_list no cut")
            return text_list

        # Measure the markup-free text and concatenate it.
        plain_bytes = _utf8_size(only_text_list)
        all_text = "".join(only_text_list)

        print("only_text_list", plain_bytes)
        # Plain text fits: return it as-is.
        if plain_bytes < max_bytes_length:
            print("return only_text_list no cut")
            return only_text_list

        # Still too large: keep a character prefix sized at a third of the
        # byte budget (CJK characters are up to 3 bytes in UTF-8).
        all_text = all_text[:int(max_bytes_length/3)]

        print("text bytes ", len(bytes(all_text, encoding='utf-8')))
        print("return only_text_list has cut")
        return [all_text]
    except Exception as e:
        logging.info("cut_str " + str(e))
        return ["-1"]
+
+
@get_memory_info.memory_decorator
def convert(data, ocr_model, otr_model):
    """Top-level conversion entry point.

    Return codes (the {code, is_success} pairs below):
    {[str], 1}: success
    {[-1], 0}: internal logic error
    {[-2], 0}: interface call error
    {[-3], 1}: bad file format, cannot be opened
    {[-4], 0}: third-party reader timed out
    {[-5], 0}: whole conversion timed out
    {[-6], 0}: Aliyun UDF queue timeout
    {[-7], 1}: file is password-protected
    :param data: {"file": base64 bytes, "type": extension string, ...}
    :param ocr_model: OCR model instance stored into module globals
    :param otr_model: OTR model instance stored into module globals
    :return: {"result_html": [...], "result_text": [...], "is_success": int}
    """

    # Memory cap (disabled).
    # soft, hard = resource.getrlimit(resource.RLIMIT_AS)
    # resource.setrlimit(resource.RLIMIT_AS, (15 * 1024 ** 3, hard))

    logging.info("into convert")
    start_time = time.time()
    try:
        # Publish the models as module globals so the interface helpers
        # (from_ocr_interface / from_otr_interface) can reach them.
        globals().update({"global_ocr_model": ocr_model})
        globals().update({"global_otr_model": otr_model})

        stream = base64.b64decode(data.get("file"))
        _type = data.get("type")

        if get_platform() == "Windows":
            # Bypass the timeout decorator and call the raw function.
            origin_unique_temp_file_process = unique_temp_file_process.__wrapped__
            text = origin_unique_temp_file_process(stream, _type)
        else:
            # On Linux the decorator enforces the overall conversion timeout.
            try:
                text = unique_temp_file_process(stream, _type)
            except TimeoutError:
                logging.info("convert time out! 1200 sec")
                text = [-5]

        if text == [-1]:
            print({"failed result": [-1], "is_success": 0}, time.time() - start_time)
            return {"result_html": ["-1"], "result_text": ["-1"], "is_success": 0}
        if text == [-2]:
            print({"failed result": [-2], "is_success": 0}, time.time() - start_time)
            return {"result_html": ["-2"], "result_text": ["-2"], "is_success": 0}
        if text == [-3]:
            print({"failed result": [-3], "is_success": 1}, time.time() - start_time)
            return {"result_html": ["-3"], "result_text": ["-3"], "is_success": 1}
        if text == [-4]:
            print({"failed result": [-4], "is_success": 0}, time.time() - start_time)
            return {"result_html": ["-4"], "result_text": ["-4"], "is_success": 0}
        if text == [-5]:
            print({"failed result": [-5], "is_success": 0}, time.time() - start_time)
            return {"result_html": ["-5"], "result_text": ["-5"], "is_success": 0}
        if text == [-7]:
            print({"failed result": [-7], "is_success": 1}, time.time() - start_time)
            return {"result_html": ["-7"], "result_text": ["-7"], "is_success": 1}

        # text = add_html_format(text)

        # On Windows, also dump the result to result.html for inspection.
        if get_platform() == "Windows":
            text_str = ""
            for t in text:
                text_str += t
            to_html("../result.html", text_str)

        # Strip markup to obtain the plain-text variant.
        only_text = []
        for t in text:
            new_t = BeautifulSoup(t, "lxml").get_text()
            new_t = re.sub("\n", "", new_t)
            only_text.append(new_t)

        # Enforce the byte-size cap on both variants.
        text = cut_str(text, only_text)
        only_text = cut_str(only_text, only_text)

        if len(only_text) == 0:
            only_text = [""]

        if only_text[0] == '' and len(only_text) <= 1:
            print({"finished result": ["", 0], "is_success": 1}, time.time() - start_time)
        else:
            print({"finished result": [str(only_text)[:20], len(str(text))],
                   "is_success": 1}, time.time() - start_time)
        return {"result_html": text, "result_text": only_text, "is_success": 1}
    except Exception as e:
        print({"failed result": [-1], "is_success": 0}, time.time() - start_time)
        print("convert", traceback.print_exc())
        return {"result_html": ["-1"], "result_text": ["-1"], "is_success": 0}
+
+
# Module globals: last processed type (debug) and the local service host.
global_type = ""
local_url = "http://127.0.0.1"
# Root working directory: next to this file on Windows, /home/admin on
# Linux (falling back to the file's directory when it does not exist).
if get_platform() == "Windows":
    _path = os.path.abspath(os.path.dirname(__file__))
else:
    _path = "/home/admin"
    if not os.path.exists(_path):
        _path = os.path.dirname(os.path.abspath(__file__))
if __name__ == '__main__':

    # Ad-hoc local test driver: load a sample file, build the request dict,
    # and run convert() twice with freshly constructed OCR/OTR models.
    print(os.path.abspath(__file__) + "/../../")
    # if len(sys.argv) == 2:
    #     port = int(sys.argv[1])
    # else:
    #     port = 15015
    # app.run(host='0.0.0.0', port=port, threaded=True, debug=False)
    # log("format_conversion running")

    # convert("", "ocr_model", "otr_model")
    # _str = "啊"
    # str1 = ""
    # str2 = ""
    # for i in range(900000):
    #     str1 += _str
    # list1 = [str1]
    # for i in range(700000):
    #     str2 += _str
    # list2 = [str2]
    # cut_str(list1, list2)

    # file_path = "C:/Users/Administrator/Desktop/error1.png"
    # file_path = "D:/Project/table-detect-master/train_data/label_1.jpg"
    # file_path = "D:/Project/table-detect-master/test_files/1.png"
    # file_path = "D:/Project/table-detect-master/test_files/table2.jpg"

    file_path = "C:/Users/Administrator/Desktop/error9.pdf"
    # file_path = "C:/Users/Administrator/Desktop/Test_Interface/test1.pdf"
    # file_path = "C:/Users/Administrator/Desktop/Test_ODPS/1624875783055.pdf"

    # file_path = "table2.jpg"

    with open(file_path, "rb") as f:
        file_bytes = f.read()
    file_base64 = base64.b64encode(file_bytes)

    data = {"file": file_base64, "type": file_path.split(".")[-1], "filemd5": 100}
    ocr_model = ocr_interface.OcrModels().get_model()
    otr_model = otr_interface.OtrModels().get_model()

    result = convert(data, ocr_model, otr_model)
    print("*"*40)
    result = convert(data, ocr_model, otr_model)
    # print(result)

+ 38 - 0
format_convert/get_memory_info.py

@@ -0,0 +1,38 @@
+import os
+import time
+from functools import wraps
+import logging
+
+import psutil
+
+from format_convert.judge_platform import get_platform
+if get_platform() == "Linux":
+    import resource
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+
+def memory_decorator(func):
+    @wraps(func)
+    def get_memory_info(*args, **kwargs):
+        if get_platform() == "Windows":
+            return func(*args, **kwargs)
+
+        # 只有linux有resource包
+        # usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+        usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
+        start_time = time.time()
+        logging.info("----- memory info start - " + func.__name__
+                     + " - " + str(round(usage, 2)) + " GB"
+                     + " - " + str(round(time.time()-start_time, 2)) + " sec")
+
+        result = func(*args, **kwargs)
+
+        # usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+        usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
+        logging.info("----- memory info end - " + func.__name__
+                     + " - " + str(round(usage, 2)) + " GB"
+                     + " - " + str(round(time.time()-start_time, 2)) + " sec")
+        return result
+
+    return get_memory_info

+ 6 - 0
format_convert/judge_platform.py

@@ -0,0 +1,6 @@
+import platform
+
+
+def get_platform():
+    sys = platform.system()
+    return sys

+ 117 - 0
format_convert/libreoffice_interface.py

@@ -0,0 +1,117 @@
+import os
+import re
+import signal
+import subprocess
+import sys
+import time
+import traceback
+import psutil
+from format_convert import timeout_decorator
+
+from format_convert import get_memory_info
+from format_convert.judge_platform import get_platform
+import logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+
+def monitor_libreoffice():
+    try:
+        # logging.info("=========================================")
+        logging.info("into monitor_libreoffice")
+        # logging.info("------------------------------MEM top 10")
+        os.system("ps aux|head -1;ps aux|grep -v PID|sort -rn -k +4|head")
+        #
+        # logging.info("--------------------------soffice process")
+        # os.system("ps -ef | grep soffice")
+
+        pids = psutil.pids()
+        for pid in pids:
+            try:
+                process = psutil.Process(pid)
+                # if process.username() == "appuser":
+                if re.search("soffice|unrar", process.exe()):
+                    # if time.time() - process.create_time() >= 120:
+
+                    # logging.info("---------------------------killed soffice")
+                    # print("process", pid, process.exe())
+                    logging.info("process " + str(pid) + str(process.exe()))
+                    comm = "kill -9 " + str(pid)
+                    # subprocess.call(comm, shell=True)
+                    os.system(comm)
+                    # print("killed", pid)
+                    logging.info("killed " + str(pid))
+
+            except TimeoutError:
+                raise TimeoutError
+            except:
+                continue
+        # logging.info("=========================================")
+    except TimeoutError:
+        raise TimeoutError
+
+
+# @timeout_decorator.timeout(120, timeout_exception=TimeoutError, use_signals=False)
+def office_convert(src_path, dest_path, target_format, retry_times=1):
+    try:
+        logging.info("into office_convert")
+        print("src_path", src_path)
+        uid1 = src_path.split(os.sep)[-1].split(".")[0]
+        dest_file_path = dest_path + uid1 + "." + target_format
+        src_format = src_path.split(".")[-1]
+
+        # 重试转换
+        for i in range(retry_times):
+            # 调用Win下的libreoffice子进程
+            if get_platform() == "Windows":
+                soffice = 'C:\\Program Files\\LibreOfficeDev 5\\program\\soffice.exe'
+                comm_list = [soffice, '--headless', '--convert-to', target_format, src_path,
+                             '--outdir', dest_path+os.sep]
+
+                try:
+                    p = subprocess.call(comm_list, timeout=30*(i+2))
+
+                except:
+                    continue
+
+            # 调用Linux下的libreoffice子进程
+            else:
+                # 先杀libreoffice进程
+                monitor_libreoffice()
+
+                # 再调用转换
+                libreoffice_dir = 'soffice'
+                comm_list = [libreoffice_dir, '--headless', '--convert-to', target_format, src_path,
+                             '--outdir', dest_path+os.sep]
+
+                comm = ''
+                for c in comm_list:
+                    comm += c + ' '
+                # logging.info("office_convert command" + comm)
+                try:
+                    # p = subprocess.call(comm_list, timeout=30*(i+2))
+                    os.system(comm)
+                except TimeoutError:
+                    return [-5]
+                except Exception as e:
+                    print(src_format + ' to ' + target_format + ' Failed! Retry...', i, 'times')
+                    print(traceback.print_exc())
+                    continue
+
+            # 执行失败,重试
+            if not os.path.exists(dest_file_path):
+                print(src_format + ' to ' + target_format + ' Failed! Retry...', i, 'times')
+                continue
+            # 执行成功,跳出循环
+            else:
+                break
+
+        # 重试后还未成功
+        if not os.path.exists(dest_file_path):
+            # print(src_format + ' to ' + target_format + ' failed!')
+            logging.info(src_format + ' to ' + target_format + " failed!")
+            return [-3]
+
+        logging.info("out office_convert")
+        return dest_file_path
+    except TimeoutError:
+        return [-5]

+ 3 - 0
format_convert/swf/__init__.py

@@ -0,0 +1,3 @@
+__version__ = "1.5.4"
+
+from . import *

+ 188 - 0
format_convert/swf/actions.py

@@ -0,0 +1,188 @@
+
+class Action(object):
+    def __init__(self, code, length):
+        self._code = code
+        self._length = length
+    
+    @property
+    def code(self):
+        return self._code
+    
+    @property   
+    def length(self):
+        return self._length;
+    
+    @property
+    def version(self):
+        return 3
+        
+    def parse(self, data):
+        # Do nothing. Many Actions don't have a payload. 
+        # For the ones that have one we override this method.
+        pass
+    
+    def __repr__(self):
+        return "[Action] Code: 0x%x, Length: %d" % (self._code, self._length)
+
+class ActionUnknown(Action):
+    ''' Dummy class to read unknown actions '''
+    def __init__(self, code, length):
+        super(ActionUnknown, self).__init__(code, length)
+    
+    def parse(self, data):
+        if self._length > 0:
+            #print "skipping %d bytes..." % self._length
+            data.skip_bytes(self._length)
+    
+    def __repr__(self):
+        return "[ActionUnknown] Code: 0x%x, Length: %d" % (self._code, self._length)
+        
+class Action4(Action):
+    ''' Base class for SWF 4 actions '''
+    def __init__(self, code, length):
+        super(Action4, self).__init__(code, length)
+    
+    @property
+    def version(self):
+        return 4
+
+class Action5(Action):
+    ''' Base class for SWF 5 actions '''
+    def __init__(self, code, length):
+        super(Action5, self).__init__(code, length)
+
+    @property
+    def version(self):
+        return 5
+        
+class Action6(Action):
+    ''' Base class for SWF 6 actions '''
+    def __init__(self, code, length):
+        super(Action6, self).__init__(code, length)
+
+    @property
+    def version(self):
+        return 6
+        
+class Action7(Action):
+    ''' Base class for SWF 7 actions '''
+    def __init__(self, code, length):
+        super(Action7, self).__init__(code, length)
+
+    @property
+    def version(self):
+        return 7
+                
+# ========================================================= 
+# SWF 3 actions
+# =========================================================
+class ActionGetURL(Action):
+    CODE = 0x83
+    def __init__(self, code, length):
+        self.urlString = None
+        self.targetString = None
+        super(ActionGetURL, self).__init__(code, length)
+        
+    def parse(self, data):
+        self.urlString = data.readString()
+        self.targetString = data.readString()
+        
+class ActionGotoFrame(Action):
+    CODE = 0x81
+    def __init__(self, code, length):
+        self.frame = 0
+        super(ActionGotoFrame, self).__init__(code, length)
+
+    def parse(self, data): 
+        self.frame = data.readUI16()
+        
+class ActionGotoLabel(Action):
+    CODE = 0x8c
+    def __init__(self, code, length):
+        self.label = None
+        super(ActionGotoLabel, self).__init__(code, length)
+
+    def parse(self, data): 
+        self.label = data.readString()
+        
+class ActionNextFrame(Action):
+    CODE = 0x04
+    def __init__(self, code, length):
+        super(ActionNextFrame, self).__init__(code, length)
+
+class ActionPlay(Action):
+    CODE = 0x06
+    def __init__(self, code, length):
+        super(ActionPlay, self).__init__(code, length)
+    
+    def __repr__(self):
+        return "[ActionPlay] Code: 0x%x, Length: %d" % (self._code, self._length)
+            
+class ActionPreviousFrame(Action):
+    CODE = 0x05
+    def __init__(self, code, length):
+        super(ActionPreviousFrame, self).__init__(code, length)
+                
+class ActionSetTarget(Action):
+    CODE = 0x8b
+    def __init__(self, code, length):
+        self.targetName = None
+        super(ActionSetTarget, self).__init__(code, length)
+
+    def parse(self, data):
+        self.targetName = data.readString()      
+        
+class ActionStop(Action):
+    CODE = 0x07
+    def __init__(self, code, length):
+        super(ActionStop, self).__init__(code, length)
+    
+    def __repr__(self):
+        return "[ActionStop] Code: 0x%x, Length: %d" % (self._code, self._length)
+             
+class ActionStopSounds(Action):
+    CODE = 0x09
+    def __init__(self, code, length):
+        super(ActionStopSounds, self).__init__(code, length)   
+        
+class ActionToggleQuality(Action):
+    CODE = 0x08
+    def __init__(self, code, length):
+        super(ActionToggleQuality, self).__init__(code, length)
+        
+class ActionWaitForFrame(Action):
+    CODE = 0x8a
+    def __init__(self, code, length):
+        self.frame = 0
+        self.skipCount = 0
+        super(ActionWaitForFrame, self).__init__(code, length)
+
+    def parse(self, data):
+        self.frame = data.readUI16()
+        self.skipCount = data.readUI8()
+                              
+# ========================================================= 
+# SWF 4 actions
+# =========================================================
+class ActionAdd(Action4):
+    CODE = 0x0a
+    def __init__(self, code, length):
+        super(ActionAdd, self).__init__(code, length)
+
+class ActionAnd(Action4):
+    CODE = 0x10
+    def __init__(self, code, length):
+        super(ActionAnd, self).__init__(code, length)
+                       
+# urgh! some 100 to go...
+
+ActionTable = {}
+for name, value in dict(locals()).items():
+    if type(value) == type and issubclass(value, Action) and hasattr(value, 'CODE'):
+        ActionTable[value.CODE] = value
+
+class SWFActionFactory(object):
+    @classmethod
+    def create(cls, code, length):
+        return ActionTable.get(code, ActionUnknown)(code, length)
+        

+ 344 - 0
format_convert/swf/consts.py

@@ -0,0 +1,344 @@
+
+class Enum(object):
+    @classmethod
+    def tostring(cls, type):
+        return cls._mapping.get(type, 'unknown')
+
+class BitmapFormat(Enum):
+    BIT_8 = 3
+    BIT_15 = 4
+    BIT_24 = 5
+    
+    _mapping = {
+        BIT_8: 'BIT_8',
+        BIT_15: 'BIT_15',
+        BIT_24: 'BIT_24',
+    }
+
+class BitmapType(Enum):
+    JPEG = 1  
+    GIF89A = 2
+    PNG = 3
+    
+    _mapping = {
+        JPEG: 'JPEG',
+        GIF89A: 'GIF89A',
+        PNG: 'PNG',
+    }
+    
+    FileExtensions = {
+        JPEG: '.jpeg',
+        GIF89A: '.gif',
+        PNG: '.png'
+    }
+
+class GradientSpreadMode(Enum):
+    PAD = 0 
+    REFLECT = 1
+    REPEAT = 2
+    
+    _mapping = {
+        PAD: 'pad',
+        REFLECT: 'reflect',
+        REPEAT: 'repeat',
+    }
+
+class GradientType(Enum):
+    LINEAR = 1
+    RADIAL = 2
+    
+    _mapping = {
+        LINEAR: 'LINEAR',
+        RADIAL: 'RADIAL',
+    }
+                
+class LineScaleMode(Enum):
+    NONE = 0
+    HORIZONTAL = 1 
+    NORMAL = 2
+    VERTICAL = 3
+    
+    _mapping = {
+        NONE: 'none',
+        HORIZONTAL: 'horizontal',
+        NORMAL: 'normal',
+        VERTICAL: 'vertical',
+    }
+                        
+class SpreadMethod(Enum):
+    PAD = 0 
+    REFLECT = 1
+    REPEAT = 2
+    
+    _mapping = {
+        PAD: 'pad',
+        REFLECT: 'reflect',
+        REPEAT: 'repeat',
+    }
+                
+class InterpolationMethod(Enum):
+    RGB = 0
+    LINEAR_RGB = 1
+    
+    _mapping = {
+        RGB: 'RGB',
+        LINEAR_RGB: 'LINEAR_RGB',
+    }
+                        
+class LineJointStyle(Enum):
+    ROUND = 0
+    BEVEL = 1
+    MITER = 2
+    
+    _mapping = {
+        ROUND: 'ROUND',
+        BEVEL: 'BEVEL',
+        MITER: 'MITER',
+    }
+        
+class LineCapsStyle(Enum):
+    ROUND = 0
+    NO = 1
+    SQUARE = 2
+    
+    _mapping = {
+        ROUND: 'ROUND',
+        NO: 'NO',
+        SQUARE: 'SQUARE',
+    }
+        
+class TextAlign(Enum):
+    LEFT = 0
+    RIGHT = 1
+    CENTER = 2
+    JUSTIFY = 3
+    
+    _mapping = {
+        LEFT: 'left',
+        RIGHT: 'right',
+        CENTER: 'center',
+        JUSTIFY: 'justify',
+    }
+        
+class BlendMode(Enum):
+    Normal = 0
+    Normal_1 = 1
+    Layer = 2
+    Multiply = 3
+    Screen = 4
+    Lighten = 5
+    Darken = 6
+    Difference = 7
+    Add = 8
+    Subtract = 9
+    Invert = 10
+    Alpha = 11
+    Erase = 12
+    Overlay = 13
+    Hardlight = 14
+    
+    _mapping = {
+        Normal: "Normal",
+        Normal_1: "Normal",
+        Layer: "Layer",
+        Multiply: "Multiply",
+        Screen: "Screen",
+        Lighten: "Lighten",
+        Darken: "Darken",
+        Difference: "Difference",
+        Add: "Add",
+        Subtract: "Subtract",
+        Invert: "Invert",
+        Alpha: "Alpha",
+        Erase: "Erase",
+        Overlay: "Overlay",
+        Hardlight: "Hardlight",
+    }
+
+class AudioSampleRate(Enum):
+    Hz5k512 = 0
+    Hz11k025 = 1
+    Hz22k05 = 2
+    Hz44k1 = 3
+    
+    _mapping = {
+        Hz5k512: '5.512kHz',
+        Hz11k025: '11.025kHz',
+        Hz22k05: '22.05kHz',
+        Hz44k1: '44.1kHz',
+    }
+    
+    Rates = {
+        Hz5k512: 5512,
+        Hz11k025: 11025,
+        Hz22k05: 22050,
+        Hz44k1: 44100,
+    }
+
+class AudioChannels(Enum):
+    Mono = 0
+    Stereo = 1
+    
+    _mapping = {
+        Mono: 'Mono',
+        Stereo: 'Stereo',
+    }
+    
+    Channels = {
+        Mono: 1,
+        Stereo: 2
+    }
+
+class AudioSampleSize(Enum):
+    b8 = 0
+    b16 = 1
+    
+    _mapping = {
+        b8: '8-bit',
+        b16: '16-bit',
+    }
+    
+    Bits = {
+        b8: 8,
+        b16: 16
+    }
+
+class AudioCodec(Enum):
+    UncompressedNativeEndian = 0
+    ADPCM = 1
+    MP3 = 2
+    UncompressedLittleEndian = 3
+    Nellymoser16kHz = 4
+    Nellymoser8kHz = 5
+    Nellymoser = 6
+    Speex = 11
+    
+    _mapping = {
+        UncompressedNativeEndian: 'UncompressedNativeEndian',
+        ADPCM: 'ADPCM',
+        MP3: 'MP3',
+        UncompressedLittleEndian: 'UncompressedLittleEndian',
+        Nellymoser16kHz: 'Nellymoser16kHz',
+        Nellymoser8kHz: 'Nellymoser8kHz',
+        Nellymoser: 'Nellymoser',
+        Speex: 'Speex',
+    }
+    
+    MinimumVersions = {
+        UncompressedNativeEndian: 1,
+        ADPCM: 1,
+        MP3: 4,
+        UncompressedLittleEndian: 4,
+        Nellymoser16kHz: 10,
+        Nellymoser8kHz: 10,
+        Nellymoser: 6,
+        Speex: 10,
+    }
+    
+    FileExtensions = {
+        MP3: '.mp3',
+        
+        # arbitrary container
+        UncompressedNativeEndian: '.wav',   
+        UncompressedLittleEndian: '.wav',
+        ADPCM: '.wav',
+        
+        # fictitious
+        Nellymoser16kHz: '.nel',
+        Nellymoser8kHz: '.nel',
+        Nellymoser: '.nel',
+        Speex: '.spx',
+    }
+    
+    MimeTypes = {
+        MP3: 'audio/mpeg',
+        UncompressedNativeEndian: 'audio/wav',   
+        UncompressedLittleEndian: 'audio/wav',
+        ADPCM: 'audio/wav',
+        
+        # assume ogg container?
+        Speex: 'audio/ogg',
+        
+        # punt
+        Nellymoser16kHz: 'application/octet-stream',
+        Nellymoser8kHz: 'application/octet-stream',
+        Nellymoser: 'application/octet-stream',
+    }
+
+class ProductEdition(Enum):
+    DeveloperEdition = 0
+    FullCommercialEdition = 1
+    NonCommercialEdition = 2
+    EducationalEdition = 3
+    NotForResaleEdition = 4
+    TrialEdition = 5
+    NoEdition = 6
+    
+    _mapping = {
+        DeveloperEdition: 'Developer edition',
+        FullCommercialEdition: 'Full commercial',
+        NonCommercialEdition: 'Non-commercial',
+        EducationalEdition: 'Educational',
+        NotForResaleEdition: 'Not for resale',
+        TrialEdition: 'Trial',
+        NoEdition: 'None',
+    }
+
+class ProductKind(Enum):
+    Unknown = 0
+    FlexForJ2EE = 1
+    FlexForDotNET = 2
+    AdobeFlex = 3
+    
+    _mapping = {
+        Unknown: 'Unknown',
+        FlexForJ2EE: 'Flex for J2EE',
+        FlexForDotNET: 'Flex for .NET',
+        AdobeFlex: 'Adobe Flex',
+    }
+
+class VideoCodec(Enum):
+    SorensonH263 = 2
+    ScreenVideo = 3
+    VP6 = 4
+    VP6Alpha = 5
+    
+    _mapping = {
+        SorensonH263: 'Sorenson H.263',
+        ScreenVideo: 'Screen video',
+        VP6: 'VP6',
+        VP6Alpha: 'VP6 with alpha',
+    }
+    
+    MinimumVersions = {
+        SorensonH263: 6,
+        ScreenVideo: 7,
+        VP6: 8,
+        VP6Alpha: 8,
+    }
+
+class MPEGVersion(Enum):
+    MPEG2_5 = 0
+    RFU = 1
+    MPEG2 = 2
+    MPEG1 = 3
+    
+    _mapping = {
+        MPEG2_5: 'MPEG2.5',
+        RFU: 'Reserved',
+        MPEG2: 'MPEG2',
+        MPEG1: 'MPEG1',
+    }
+
+class MPEGLayer(Enum):
+    RFU = 0
+    Layer3 = 1
+    Layer2 = 2
+    Layer1 = 3
+    
+    _mapping = {
+        RFU: 'Reserved',
+        Layer3: 'Layer 3',
+        Layer2: 'Layer 2',
+        Layer1: 'Layer 1',
+    }

+ 1437 - 0
format_convert/swf/data.py

@@ -0,0 +1,1437 @@
+from .consts import *
+from .utils import *
+
+class _dumb_repr(object):
+    def __repr__(self):
+        return '<%s %r>' % (self.__class__.__name__, self.__dict__)
+
+class SWFRawTag(_dumb_repr):
+    def __init__(self, s=None):
+        if not s is None:
+            self.parse(s)
+
+    def parse(self, s):
+        pos = s.tell()
+        self.header = s.readtag_header()
+        self.pos_content = s.tell()
+        s.f.seek(pos)
+        #self.bytes = s.f.read(self.header.tag_length())
+        #s.f.seek(self.pos_content)
+
+class SWFStraightEdge(_dumb_repr):
+    def __init__(self, start, to, line_style_idx, fill_style_idx):
+        self.start = start
+        self.to = to
+        self.line_style_idx = line_style_idx
+        self.fill_style_idx = fill_style_idx
+
+    def reverse_with_new_fillstyle(self, new_fill_idx):
+        return SWFStraightEdge(self.to, self.start, self.line_style_idx, new_fill_idx)
+
+class SWFCurvedEdge(SWFStraightEdge):
+    def __init__(self, start, control, to, line_style_idx, fill_style_idx):
+        super(SWFCurvedEdge, self).__init__(start, to, line_style_idx, fill_style_idx)
+        self.control = control
+
+    def reverse_with_new_fillstyle(self, new_fill_idx):
+        return SWFCurvedEdge(self.to, self.control, self.start, self.line_style_idx, new_fill_idx)
+
+class SWFShape(_dumb_repr):
+    def __init__(self, data=None, level=1, unit_divisor=20.0):
+        self._records = []
+        self._fillStyles = []
+        self._lineStyles = []
+        self._postLineStyles = {}
+        self._edgeMapsCreated = False
+        self.unit_divisor = unit_divisor
+        self.fill_edge_maps = []
+        self.line_edge_maps = []
+        self.current_fill_edge_map = {}
+        self.current_line_edge_map = {}
+        self.num_groups = 0
+        self.coord_map = {}
+        if not data is None:
+            self.parse(data, level)
+
+    def get_dependencies(self):
+        s = set()
+        for x in self._fillStyles:
+            s.update(x.get_dependencies())
+        for x in self._lineStyles:
+            s.update(x.get_dependencies())
+        return s
+
+    def parse(self, data, level=1):
+        data.reset_bits_pending()
+        fillbits = data.readUB(4)
+        linebits = data.readUB(4)
+        self.read_shape_records(data, fillbits, linebits, level)
+
+    def export(self, handler=None):
+        self._create_edge_maps()
+        if handler is None:
+            from export import SVGShapeExporter
+            handler = SVGShapeExporter()
+        handler.begin_shape()
+        for i in range(0, self.num_groups):
+            self._export_fill_path(handler, i)
+            self._export_line_path(handler, i)
+        handler.end_shape()
+        return handler
+
+    @property
+    def records(self):
+        return self._records
+
+    def read_shape_records(self, data, fill_bits, line_bits, level=1):
+        shape_record = None
+        record_id = 0
+        while type(shape_record) != SWFShapeRecordEnd:
+            # The SWF10 spec says that shape records are byte aligned.
+            # In reality they seem not to be?
+            # bitsPending = 0;
+            edge_record = (data.readUB(1) == 1)
+            if edge_record:
+                straight_flag = (data.readUB(1) == 1)
+                num_bits = data.readUB(4) + 2
+                if straight_flag:
+                    shape_record = data.readSTRAIGHTEDGERECORD(num_bits)
+                else:
+                    shape_record = data.readCURVEDEDGERECORD(num_bits)
+            else:
+                states= data.readUB(5)
+                if states == 0:
+                    shape_record = SWFShapeRecordEnd()
+                else:
+                    style_change_record = data.readSTYLECHANGERECORD(states, fill_bits, line_bits, level)
+                    if style_change_record.state_new_styles:
+                        fill_bits = style_change_record.num_fillbits
+                        line_bits = style_change_record.num_linebits
+                    shape_record = style_change_record
+            shape_record.record_id = record_id
+            self._records.append(shape_record)
+            record_id += 1
+            #print shape_record.tostring()
+
+    def _create_edge_maps(self):
+        if self._edgeMapsCreated:
+            return
+        xPos = 0
+        yPos = 0
+        sub_path = []
+        fs_offset = 0
+        ls_offset = 0
+        curr_fs_idx0 = 0
+        curr_fs_idx1 = 0
+        curr_ls_idx = 0
+
+        self.fill_edge_maps = []
+        self.line_edge_maps = []
+        self.current_fill_edge_map = {}
+        self.current_line_edge_map = {}
+        self.num_groups = 0
+
+        for i in range(0, len(self._records)):
+            rec = self._records[i]
+            if rec.type == SWFShapeRecord.TYPE_STYLECHANGE:
+                if rec.state_line_style or rec.state_fill_style0 or rec.state_fill_style1:
+                    if len(sub_path):
+                        self._process_sub_path(sub_path, curr_ls_idx, curr_fs_idx0, curr_fs_idx1, rec.record_id)
+                    sub_path = []
+
+                if rec.state_new_styles:
+                    fs_offset = len(self._fillStyles)
+                    ls_offset = len(self._lineStyles)
+                    self._append_to(self._fillStyles, rec.fill_styles)
+                    self._append_to(self._lineStyles, rec.line_styles)
+
+                if rec.state_line_style and rec.state_fill_style0 and rec.state_fill_style1 and \
+                    rec.line_style == 0 and rec.fill_style0 == 0 and rec.fill_style1 == 0:
+                    # new group (probably)
+                    self._clean_edge_map(self.current_fill_edge_map)
+                    self._clean_edge_map(self.current_line_edge_map)
+                    self.fill_edge_maps.append(self.current_fill_edge_map)
+                    self.line_edge_maps.append(self.current_line_edge_map)
+                    self.current_fill_edge_map = {}
+                    self.current_line_edge_map = {}
+                    self.num_groups += 1
+                    curr_fs_idx0 = 0
+                    curr_fs_idx1 = 0
+                    curr_ls_idx = 0
+                else:
+                    if rec.state_line_style:
+                        curr_ls_idx = rec.line_style
+                        if curr_ls_idx > 0:
+                            curr_ls_idx += ls_offset
+                    if rec.state_fill_style0:
+                        curr_fs_idx0 = rec.fill_style0
+                        if curr_fs_idx0 > 0:
+                            curr_fs_idx0 += fs_offset
+                    if rec.state_fill_style1:
+                        curr_fs_idx1 = rec.fill_style1
+                        if curr_fs_idx1 > 0:
+                            curr_fs_idx1 += fs_offset
+
+                if rec.state_moveto:
+                    xPos = rec.move_deltaX
+                    yPos = rec.move_deltaY
+            elif rec.type == SWFShapeRecord.TYPE_STRAIGHTEDGE:
+                start = [NumberUtils.round_pixels_400(xPos), NumberUtils.round_pixels_400(yPos)]
+                if rec.general_line_flag:
+                    xPos += rec.deltaX
+                    yPos += rec.deltaY
+                else:
+                    if rec.vert_line_flag:
+                        yPos += rec.deltaY
+                    else:
+                        xPos += rec.deltaX
+                to = [NumberUtils.round_pixels_400(xPos), NumberUtils.round_pixels_400(yPos)]
+                sub_path.append(SWFStraightEdge(start, to, curr_ls_idx, curr_fs_idx1))
+            elif rec.type == SWFShapeRecord.TYPE_CURVEDEDGE:
+                start = [NumberUtils.round_pixels_400(xPos), NumberUtils.round_pixels_400(yPos)]
+                xPosControl = xPos + rec.control_deltaX
+                yPosControl = yPos + rec.control_deltaY
+                xPos = xPosControl + rec.anchor_deltaX
+                yPos = yPosControl + rec.anchor_deltaY
+                control = [xPosControl, yPosControl]
+                to = [NumberUtils.round_pixels_400(xPos), NumberUtils.round_pixels_400(yPos)]
+                sub_path.append(SWFCurvedEdge(start, control, to, curr_ls_idx, curr_fs_idx1))
+            elif rec.type == SWFShapeRecord.TYPE_END:
+                # We're done. Process the last subpath, if any
+                if len(sub_path) > 0:
+                    self._process_sub_path(sub_path, curr_ls_idx, curr_fs_idx0, curr_fs_idx1, rec.record_id)
+                    self._clean_edge_map(self.current_fill_edge_map)
+                    self._clean_edge_map(self.current_line_edge_map)
+                    self.fill_edge_maps.append(self.current_fill_edge_map)
+                    self.line_edge_maps.append(self.current_line_edge_map)
+                    self.current_fill_edge_map = {}
+                    self.current_line_edge_map = {}
+                    self.num_groups += 1
+                curr_fs_idx0 = 0
+                curr_fs_idx1 = 0
+                curr_ls_idx = 0
+
+        self._edgeMapsCreated = True
+
+    def _process_sub_path(self, sub_path, linestyle_idx, fillstyle_idx0, fillstyle_idx1, record_id=-1):
+        path = None
+        if fillstyle_idx0 != 0:
+            if not fillstyle_idx0 in self.current_fill_edge_map:
+                path = self.current_fill_edge_map[fillstyle_idx0] = []
+            else:
+                path = self.current_fill_edge_map[fillstyle_idx0]
+            for j in range(len(sub_path) - 1, -1, -1):
+                path.append(sub_path[j].reverse_with_new_fillstyle(fillstyle_idx0))
+
+        if fillstyle_idx1 != 0:
+            if not fillstyle_idx1 in self.current_fill_edge_map:
+                path = self.current_fill_edge_map[fillstyle_idx1] = []
+            else:
+                path = self.current_fill_edge_map[fillstyle_idx1]
+            self._append_to(path, sub_path)
+
+        if linestyle_idx != 0:
+            if not linestyle_idx in self.current_line_edge_map:
+                path = self.current_line_edge_map[linestyle_idx] = []
+            else:
+                path = self.current_line_edge_map[linestyle_idx]
+            self._append_to(path, sub_path)
+
+    def _clean_edge_map(self, edge_map):
+        for style_idx in edge_map:
+            sub_path = edge_map[style_idx] if style_idx in edge_map else None
+            if sub_path is not None and len(sub_path) > 0:
+                tmp_path = []
+                prev_edge = None
+                self._create_coord_map(sub_path)
+                while len(sub_path) > 0:
+                    idx = 0
+                    while idx < len(sub_path):
+                        if prev_edge is None or self._equal_point(prev_edge.to, sub_path[idx].start):
+                            edge = sub_path[idx]
+                            del sub_path[idx]
+                            tmp_path.append(edge)
+                            self._remove_edge_from_coord_map(edge)
+                            prev_edge = edge
+                        else:
+                            edge = self._find_next_edge_in_coord_map(prev_edge)
+                            if not edge is None:
+                                idx = sub_path.index(edge)
+                            else:
+                                idx = 0
+                                prev_edge = None
+                edge_map[style_idx] = tmp_path
+
+    def _equal_point(self, a, b, tol=0.001):
+        return (a[0] > b[0]-tol and a[0] < b[0]+tol and a[1] > b[1]-tol and a[1] < b[1]+tol)
+
+    def _find_next_edge_in_coord_map(self, edge):
+        key = "%0.4f_%0.4f" % (edge.to[0], edge.to[1])
+        if key in self.coord_map and len(self.coord_map[key]) > 0:
+            return self.coord_map[key][0]
+        else:
+            return None
+
+    def _create_coord_map(self, path):
+        self.coord_map = {}
+        for i in range(0, len(path)):
+            start = path[i].start
+            key = "%0.4f_%0.4f" % (start[0], start[1])
+            coord_map_array = self.coord_map[key] if key in self.coord_map else None
+            if coord_map_array is None:
+                self.coord_map[key] = [path[i]]
+            else:
+                self.coord_map[key].append(path[i])
+
+    def _remove_edge_from_coord_map(self, edge):
+        key = "%0.4f_%0.4f" % (edge.start[0], edge.start[1])
+        if key in self.coord_map:
+            coord_map_array = self.coord_map[key]
+            if len(coord_map_array) == 1:
+                del self.coord_map[key]
+            else:
+                try:
+                    idx = coord_map_array.index(edge)
+                    del coord_map_array[idx]
+                except:
+                    pass
+
+    def _create_path_from_edge_map(self, edge_map):
+        new_path = []
+        style_ids = []
+        for style_id in edge_map:
+            style_ids.append(int(style_id))
+        style_ids = sorted(style_ids)
+        for i in range(0, len(style_ids)):
+            self._append_to(new_path, edge_map[style_ids[i]])
+        return new_path
+
    def _export_fill_path(self, handler, group_index):
        """Replay the fill edges of one group into `handler`.

        Walks the style-sorted edge path, emitting begin_fill /
        begin_gradient_fill / begin_bitmap_fill whenever the edge's fill
        style changes, and move/line/curve commands for the edges.
        """
        path = self._create_path_from_edge_map(self.fill_edge_maps[group_index])

        # sentinel position far off-canvas so the first edge forces a move_to
        pos = [100000000, 100000000]
        # conversion factor applied to twip coordinates before handing out
        u = 1.0 / self.unit_divisor
        # sentinel style index so the first edge always triggers a style setup
        fill_style_idx = 10000000

        if len(path) < 1:
            return
        handler.begin_fills()
        for i in range(0, len(path)):
            e = path[i]
            if fill_style_idx != e.fill_style_idx:
                # style changed: reset the pen and announce the new fill
                fill_style_idx = e.fill_style_idx
                pos = [100000000, 100000000]
                try:
                    # style indices are 1-based; 0 means "no fill"
                    fill_style = self._fillStyles[fill_style_idx - 1] if fill_style_idx > 0 else None
                    if fill_style.type == 0x0:
                        # solid fill
                        handler.begin_fill(
                            ColorUtils.rgb(fill_style.rgb),
                            ColorUtils.alpha(fill_style.rgb))
                    elif fill_style.type in [0x10, 0x12, 0x13]:
                        # gradient fill: flatten records into parallel lists
                        colors = []
                        ratios = []
                        alphas = []
                        for j in range(0, len(fill_style.gradient.records)):
                            gr = fill_style.gradient.records[j]
                            colors.append(ColorUtils.rgb(gr.color))
                            ratios.append(gr.ratio)
                            alphas.append(ColorUtils.alpha(gr.color))
                        handler.begin_gradient_fill(
                            GradientType.LINEAR if fill_style.type == 0x10 else GradientType.RADIAL,
                            colors, alphas, ratios,
                            fill_style.gradient_matrix,
                            fill_style.gradient.spreadmethod,
                            fill_style.gradient.interpolation_mode,
                            fill_style.gradient.focal_point
                            )
                    elif fill_style.type in [0x40, 0x41, 0x42, 0x43]:
                        # bitmap fill
                        # NOTE(review): the two booleans appear to encode
                        # (repeat, smooth) from the fill-type bits -- confirm
                        handler.begin_bitmap_fill(
                            fill_style.bitmap_id,
                            fill_style.bitmap_matrix,
                            (fill_style.type == 0x40 or fill_style.type == 0x42),
                            (fill_style.type == 0x40 or fill_style.type == 0x41)
                            )
                        pass
                except:
                    # Font shapes define no fillstyles per se, but do reference fillstyle index 1,
                    # which represents the font color. We just report solid black in this case.
                    handler.begin_fill(0)

            # only move the pen when this edge does not continue from `pos`
            if not self._equal_point(pos, e.start):
                handler.move_to(e.start[0] * u, e.start[1] * u)

            if type(e) is SWFCurvedEdge:
                handler.curve_to(e.control[0] * u, e.control[1] * u, e.to[0] * u, e.to[1] * u)
            else:
                handler.line_to(e.to[0] * u, e.to[1] * u)

            pos = e.to

        handler.end_fill()
        handler.end_fills()
+
+    def _export_line_path(self, handler, group_index):
+
+        path = self._create_path_from_edge_map(self.line_edge_maps[group_index])
+        pos = [100000000, 100000000]
+        u = 1.0 / self.unit_divisor
+        line_style_idx = 10000000
+        line_style = None
+        if len(path) < 1:
+            return
+
+        handler.begin_lines()
+        for i in range(0, len(path)):
+            e = path[i]
+
+            if line_style_idx != e.line_style_idx:
+                line_style_idx = e.line_style_idx
+                pos = [100000000, 100000000]
+                try:
+                    line_style = self._lineStyles[line_style_idx - 1]
+                except:
+                    line_style = None
+                if line_style is not None:
+                    scale_mode = LineScaleMode.NORMAL
+                    if line_style.no_hscale_flag and line_style.no_vscale_flag:
+                        scale_mode = LineScaleMode.NONE
+                    elif line_style.no_hscale_flag:
+                        scale_mode = LineScaleMode.HORIZONTAL
+                    elif line_style.no_hscale_flag:
+                        scale_mode = LineScaleMode.VERTICAL
+
+                    if not line_style.has_fill_flag:
+                        handler.line_style(
+                            line_style.width / 20.0,
+                            ColorUtils.rgb(line_style.color),
+                            ColorUtils.alpha(line_style.color),
+                            line_style.pixelhinting_flag,
+                            scale_mode,
+                            line_style.start_caps_style,
+                            line_style.end_caps_style,
+                            line_style.joint_style,
+                            line_style.miter_limit_factor)
+                    else:
+                        fill_style = line_style.fill_type
+
+                        if fill_style.type in [0x10, 0x12, 0x13]:
+                            # gradient fill
+                            colors = []
+                            ratios = []
+                            alphas = []
+                            for j in range(0, len(fill_style.gradient.records)):
+                                gr = fill_style.gradient.records[j]
+                                colors.append(ColorUtils.rgb(gr.color))
+                                ratios.append(gr.ratio)
+                                alphas.append(ColorUtils.alpha(gr.color))
+
+                            handler.line_gradient_style(
+                                line_style.width / 20.0,
+                                line_style.pixelhinting_flag,
+                                scale_mode,
+                                line_style.start_caps_style,
+                                line_style.end_caps_style,
+                                line_style.joint_style,
+                                line_style.miter_limit_factor,
+                                GradientType.LINEAR if fill_style.type == 0x10 else GradientType.RADIAL,
+                                colors, alphas, ratios,
+                                fill_style.gradient_matrix,
+                                fill_style.gradient.spreadmethod,
+                                fill_style.gradient.interpolation_mode,
+                                fill_style.gradient.focal_point
+                                )
+                        elif fill_style.type in [0x40, 0x41, 0x42]:
+                            handler.line_bitmap_style(
+                                line_style.width / 20.0,
+                                line_style.pixelhinting_flag,
+                                scale_mode,
+                                line_style.start_caps_style,
+                                line_style.end_caps_style,
+                                line_style.joint_style,
+                                line_style.miter_limit_factor,
+                                fill_style.bitmap_id, fill_style.bitmap_matrix,
+                                (fill_style.type == 0x40 or fill_style.type == 0x42),
+                                (fill_style.type == 0x40 or fill_style.type == 0x41)
+                                )
+                else:
+                    # we should never get here
+                    handler.line_style(0)
+            if not self._equal_point(pos, e.start):
+                handler.move_to(e.start[0] * u, e.start[1] * u)
+            if type(e) is SWFCurvedEdge:
+                handler.curve_to(e.control[0] * u, e.control[1] * u, e.to[0] * u, e.to[1] * u)
+            else:
+                handler.line_to(e.to[0] * u, e.to[1] * u)
+            pos = e.to
+        handler.end_lines()
+
+    def _append_to(self, v1, v2):
+        for i in range(0, len(v2)):
+            v1.append(v2[i])
+
+    def __str__(self):
+        return "[SWFShape]"
+
class SWFShapeWithStyle(SWFShape):
    """A shape carrying its own initial fill and line style tables
    (DefineShape-family tags), unlike a bare glyph shape."""

    def __init__(self, data, level, unit_divisor):
        self._initialFillStyles = []
        self._initialLineStyles = []
        super(SWFShapeWithStyle, self).__init__(data, level, unit_divisor)

    def export(self, handler=None):
        # Make the initial style tables visible to the base exporter.
        # NOTE(review): calling export() twice re-appends these styles;
        # presumably export runs once per shape -- confirm.
        self._fillStyles.extend(self._initialFillStyles)
        self._lineStyles.extend(self._initialLineStyles)
        return super(SWFShapeWithStyle, self).export(handler)

    def get_dependencies(self):
        """Return the union of character ids referenced by every style."""
        s = set()
        for x in self._fillStyles + self._initialFillStyles:
            s.update(x.get_dependencies())
        for x in self._lineStyles + self._initialLineStyles:
            s.update(x.get_dependencies())
        return s

    def parse(self, data, level=1):
        """Read the style tables, then the shape records that use them."""
        data.reset_bits_pending()
        num_fillstyles = self.readstyle_array_length(data, level)
        for i in range(0, num_fillstyles):
            self._initialFillStyles.append(data.readFILLSTYLE(level))
        num_linestyles = self.readstyle_array_length(data, level)
        for i in range(0, num_linestyles):
            # level 4 uses the extended LINESTYLE2 record
            if level <= 3:
                self._initialLineStyles.append(data.readLINESTYLE(level))
            else:
                self._initialLineStyles.append(data.readLINESTYLE2(level))
        num_fillbits = data.readUB(4)
        num_linebits = data.readUB(4)
        data.reset_bits_pending()
        self.read_shape_records(data, num_fillbits, num_linebits, level)

    def readstyle_array_length(self, data, level=1):
        """Style-array count: one byte, with 0xff escaping to a 16-bit count (level >= 2)."""
        length = data.readUI8()
        if level >= 2 and length == 0xff:
            length = data.readUI16()
        return length

    def __str__(self):
        # Bug fix: the header guard tested self._fillStyles while the loop
        # below prints self._initialFillStyles; test the list actually printed
        # (mirrors the LineStyles guard).
        s = "    FillStyles:\n" if len(self._initialFillStyles) > 0 else ""
        for i in range(0, len(self._initialFillStyles)):
            s += "        %d:%s\n" % (i + 1, self._initialFillStyles[i].__str__())
        if len(self._initialLineStyles) > 0:
            s += "    LineStyles:\n"
            for i in range(0, len(self._initialLineStyles)):
                s += "        %d:%s\n" % (i + 1, self._initialLineStyles[i].__str__())
        for record in self._records:
            s += record.__str__() + '\n'
        return s.rstrip() + super(SWFShapeWithStyle, self).__str__()
+
class SWFShapeRecord(_dumb_repr):
    """Base class for the records making up a shape's edge list."""

    TYPE_UNKNOWN = 0
    TYPE_END = 1
    TYPE_STYLECHANGE = 2
    TYPE_STRAIGHTEDGE = 3
    TYPE_CURVEDEDGE = 4

    # class-level default id; shared by all records unless overridden
    record_id = -1

    def __init__(self, data=None, level=1):
        if data is not None:
            self.parse(data, level)

    @property
    def is_edge_record(self):
        """True for straight- and curved-edge records."""
        return self.type in (SWFShapeRecord.TYPE_STRAIGHTEDGE,
                             SWFShapeRecord.TYPE_CURVEDEDGE)

    def parse(self, data, level=1):
        """Subclasses read their payload here; the base record has none."""
        pass

    @property
    def type(self):
        return SWFShapeRecord.TYPE_UNKNOWN

    def __str__(self):
        return "    [SWFShapeRecord]"
+
class SWFShapeRecordStraightEdge(SWFShapeRecord):
    """A straight edge: either a general delta or an axis-aligned one."""

    def __init__(self, data, num_bits=0, level=1):
        self.num_bits = num_bits
        super(SWFShapeRecordStraightEdge, self).__init__(data, level)

    def parse(self, data, level=1):
        """Read the edge flags and delta(s); stream order must be preserved."""
        self.general_line_flag = (data.readUB(1) == 1)
        if self.general_line_flag:
            # general line: both deltas follow
            self.vert_line_flag = False
            self.deltaX = data.readSB(self.num_bits)
            self.deltaY = data.readSB(self.num_bits)
        else:
            # axis-aligned: a direction bit, then a single delta
            self.vert_line_flag = (data.readUB(1) == 1)
            if self.vert_line_flag:
                self.deltaX = 0.0
                self.deltaY = data.readSB(self.num_bits)
            else:
                self.deltaX = data.readSB(self.num_bits)
                self.deltaY = 0.0

    @property
    def type(self):
        return SWFShapeRecord.TYPE_STRAIGHTEDGE

    def __str__(self):
        if self.general_line_flag:
            detail = " General: %d %d" % (self.deltaX, self.deltaY)
        elif self.vert_line_flag:
            detail = " Vertical: %d" % self.deltaY
        else:
            detail = " Horizontal: %d" % self.deltaX
        return "    [SWFShapeRecordStraightEdge]" + detail
+
class SWFShapeRecordCurvedEdge(SWFShapeRecord):
    """A quadratic curve edge: control-point delta, then anchor delta."""

    def __init__(self, data, num_bits=0, level=1):
        self.num_bits = num_bits
        super(SWFShapeRecordCurvedEdge, self).__init__(data, level)

    def parse(self, data, level=1):
        bits = self.num_bits
        # stream order matters: control X/Y, then anchor X/Y
        self.control_deltaX = data.readSB(bits)
        self.control_deltaY = data.readSB(bits)
        self.anchor_deltaX = data.readSB(bits)
        self.anchor_deltaY = data.readSB(bits)

    @property
    def type(self):
        return SWFShapeRecord.TYPE_CURVEDEDGE

    def __str__(self):
        parts = [
            "    [SWFShapeRecordCurvedEdge]",
            " ControlDelta: %d, %d" % (self.control_deltaX, self.control_deltaY),
            " AnchorDelta: %d, %d" % (self.anchor_deltaX, self.anchor_deltaY),
        ]
        return "".join(parts)
+
class SWFShapeRecordStyleChange(SWFShapeRecord):
    """A style-change record: may move the pen, select fill/line styles,
    and/or replace the style tables, according to the `states` flag bits."""

    def __init__(self, data, states=0, fill_bits=0, line_bits=0, level=1):
        self.fill_styles = []
        self.line_styles = []
        # decode the five state flag bits
        self.state_new_styles = ((states & 0x10) != 0)
        self.state_line_style = ((states & 0x08) != 0)
        self.state_fill_style1 = ((states & 0x4) != 0)
        self.state_fill_style0 = ((states & 0x2) != 0)
        self.state_moveto = ((states & 0x1) != 0)
        self.num_fillbits = fill_bits
        self.num_linebits = line_bits
        self.move_deltaX = 0.0
        self.move_deltaY = 0.0
        self.fill_style0 = 0
        self.fill_style1 = 0
        self.line_style = 0
        super(SWFShapeRecordStyleChange, self).__init__(data, level)

    def parse(self, data, level=1):
        """Read the optional move, style selections and new style tables.

        Field order follows the SWF bit stream and must not change.
        """
        if self.state_moveto:
            movebits = data.readUB(5)
            self.move_deltaX = data.readSB(movebits)
            self.move_deltaY = data.readSB(movebits)
        self.fill_style0 = data.readUB(self.num_fillbits) if self.state_fill_style0 else 0
        self.fill_style1 = data.readUB(self.num_fillbits) if self.state_fill_style1 else 0
        self.line_style = data.readUB(self.num_linebits) if self.state_line_style else 0
        if self.state_new_styles:
            # a fresh pair of style tables replaces the current ones
            data.reset_bits_pending();
            num_fillstyles = self.readstyle_array_length(data, level)
            for i in range(0, num_fillstyles):
                self.fill_styles.append(data.readFILLSTYLE(level))
            num_linestyles = self.readstyle_array_length(data, level)
            for i in range(0, num_linestyles):
                # level 4 uses the extended LINESTYLE2 record
                if level <= 3:
                    self.line_styles.append(data.readLINESTYLE(level))
                else:
                    self.line_styles.append(data.readLINESTYLE2(level))
            self.num_fillbits = data.readUB(4)
            self.num_linebits = data.readUB(4)

    @property
    def type(self):
        return SWFShapeRecord.TYPE_STYLECHANGE

    def readstyle_array_length(self, data, level=1):
        """Style-array count: one byte, with 0xff escaping to a 16-bit count (level >= 2)."""
        length = data.readUI8()
        if level >= 2 and length == 0xff:
            length = data.readUI16()
        return length

    def __str__(self):
        return "    [SWFShapeRecordStyleChange]" + \
            " moveTo: %d %d" % (self.move_deltaX, self.move_deltaY) + \
            " fs0: %d" % self.fill_style0 + \
            " fs1: %d" % self.fill_style1 + \
            " linestyle: %d" % self.line_style + \
            " flags: %d %d %d" % (self.state_fill_style0, self.state_fill_style1, self.state_line_style)
+
class SWFShapeRecordEnd(SWFShapeRecord):
    """Marks the end of a shape's record list; carries no payload."""

    def __init__(self):
        super(SWFShapeRecordEnd, self).__init__(None)

    def parse(self, data, level=1):
        # an end record has nothing to read
        pass

    @property
    def type(self):
        return SWFShapeRecord.TYPE_END

    def __str__(self):
        return "    [SWFShapeRecordEnd]"
+
class SWFMatrix(_dumb_repr):
    """A 2x3 affine transform (scale, rotate/skew, translate)."""

    def __init__(self, data):
        self.scaleX = 1.0
        self.scaleY = 1.0
        self.rotateSkew0 = 0.0
        self.rotateSkew1 = 0.0
        self.translateX = 0.0
        self.translateY = 0.0
        if not data is None:
            self.parse(data)

    def parse(self, data):
        """Read MATRIX fields; scale and rotate/skew blocks are optional."""
        data.reset_bits_pending();
        self.scaleX = 1.0
        self.scaleY = 1.0
        if data.readUB(1) == 1:
            # has-scale flag set: a 5-bit field width, then two values
            scaleBits = data.readUB(5)
            self.scaleX = data.readFB(scaleBits)
            self.scaleY = data.readFB(scaleBits)
        self.rotateSkew0 = 0.0
        self.rotateSkew1 = 0.0
        if data.readUB(1) == 1:
            # has-rotate flag set
            rotateBits = data.readUB(5)
            self.rotateSkew0 = data.readFB(rotateBits)
            self.rotateSkew1 = data.readFB(rotateBits)
        # the translation terms are always present
        translateBits = data.readUB(5)
        self.translateX = data.readSB(translateBits)
        self.translateY = data.readSB(translateBits)

    def to_array(self):
        """Return the six matrix terms as a flat list."""
        return [
            self.scaleX, self.rotateSkew0,
            self.rotateSkew1, self.scaleY,
            self.translateX, self.translateY
        ]

    def __str__(self):
        def fmt(s):
            return "%0.2f" % s

        return "[%s]" % ",".join(map(fmt, self.to_array()))
+
class SWFGradientRecord(_dumb_repr):
    """One gradient stop: a ratio (0-255) and a color."""

    def __init__(self, data=None, level=1):
        self._records = []  # NOTE(review): never read by this class -- confirm it is needed
        if data is not None:
            self.parse(data, level)

    def parse(self, data, level=1):
        self.ratio = data.readUI8()
        # levels 1-2 store RGB, level 3+ stores RGBA
        self.color = data.readRGB() if level <= 2 else data.readRGBA()

    def __str__(self):
        return "[SWFGradientRecord] Color: %s, Ratio: %d" % (
            ColorUtils.to_rgb_string(self.color), self.ratio)
+
class SWFGradient(_dumb_repr):
    """A gradient: spread and interpolation modes plus its color records."""

    def __init__(self, data=None, level=1):
        self._records = []
        self.focal_point = 0.0  # only meaningful for SWFFocalGradient
        if not data is None:
            self.parse(data, level)

    @property
    def records(self):
        """The list of gradient stop records."""
        return self._records

    def parse(self, data, level=1):
        data.reset_bits_pending()
        self.spreadmethod = data.readUB(2)
        self.interpolation_mode = data.readUB(2)
        num_gradients = data.readUB(4)
        for i in range(0, num_gradients):
            self._records.append(data.readGRADIENTRECORD(level))

    def __str__(self):
        # Bug fix: the debug label was misspelled "[SWFGadient]"
        s = "[SWFGradient]"
        for record in self._records:
            s += "\n  " + record.__str__()
        return s
+
class SWFFocalGradient(SWFGradient):
    """A gradient with an additional focal point (fill type 0x13)."""

    def __init__(self, data=None, level=1):
        super(SWFFocalGradient, self).__init__(data, level)

    def parse(self, data, level=1):
        super(SWFFocalGradient, self).parse(data, level)
        # the focal point follows the gradient records (8.8 fixed point)
        self.focal_point = data.readFIXED8()

    def __str__(self):
        # Bug fix: the old format string referenced self.color and self.ratio,
        # which a gradient does not define (they live on its records), so this
        # always raised AttributeError.  Report the focal point and records.
        s = "[SWFFocalGradient] Focal: %0.2f" % self.focal_point
        for record in self._records:
            s += "\n  " + record.__str__()
        return s
+
class SWFFillStyle(_dumb_repr):
    """A fill style: solid color, gradient, or bitmap fill."""

    # recognized fill type codes, grouped by family
    COLOR = [0x0]
    GRADIENT = [0x10, 0x12, 0x13]
    BITMAP = [0x40, 0x41, 0x42, 0x43]

    def __init__(self, data=None, level=1):
        if data is not None:
            self.parse(data, level)

    def parse(self, data, level=1):
        fill_type = data.readUI8()
        self.type = fill_type
        if fill_type in SWFFillStyle.COLOR:
            # solid fill: RGB up to level 2, RGBA from level 3
            self.rgb = data.readRGB() if level <= 2 else data.readRGBA()
        elif fill_type in SWFFillStyle.GRADIENT:
            self.gradient_matrix = data.readMATRIX()
            # 0x13 is the focal radial gradient variant
            self.gradient = data.readFOCALGRADIENT(level) if fill_type == 0x13 else data.readGRADIENT(level)
        elif fill_type in SWFFillStyle.BITMAP:
            self.bitmap_id = data.readUI16()
            self.bitmap_matrix = data.readMATRIX()
        else:
            raise Exception("Unknown fill style type: 0x%x" % fill_type, level)

    def get_dependencies(self):
        """Bitmap fills depend on their bitmap character id."""
        return set([self.bitmap_id]) if self.type in SWFFillStyle.BITMAP else set()

    def __str__(self):
        s = "[SWFFillStyle] "
        if self.type in SWFFillStyle.COLOR:
            s += "Color: %s" % ColorUtils.to_rgb_string(self.rgb)
        elif self.type in SWFFillStyle.GRADIENT:
            s += "Gradient: %s" % self.gradient_matrix
        elif self.type in SWFFillStyle.BITMAP:
            s += "BitmapID: %d" % (self.bitmap_id)
        return s
+
class SWFLineStyle(_dumb_repr):
    """A basic line style: width and color."""

    def __init__(self, data=None, level=1):
        # Defaults double as forward declarations for SWFLineStyle2, so
        # consumers may read the extended attributes on either class.
        self.start_caps_style = LineCapsStyle.ROUND
        self.end_caps_style = LineCapsStyle.ROUND
        self.joint_style = LineJointStyle.ROUND
        self.has_fill_flag = False
        self.no_hscale_flag = False
        self.no_vscale_flag = False
        self.pixelhinting_flag = False
        self.no_close = False
        self.miter_limit_factor = 3.0
        self.fill_type = None
        self.width = 1
        self.color = 0
        if data is not None:
            self.parse(data, level)

    def get_dependencies(self):
        """A plain line style references no other characters."""
        return set()

    def parse(self, data, level=1):
        self.width = data.readUI16()
        # levels 1-2 store RGB, level 3+ stores RGBA
        self.color = data.readRGB() if level <= 2 else data.readRGBA()

    def __str__(self):
        return "[SWFLineStyle] " + "Color: %s, Width: %d" % (
            ColorUtils.to_rgb_string(self.color), self.width)
+
class SWFLineStyle2(SWFLineStyle):
    """Extended line style (level 4): caps, joints, scaling flags and fills."""

    def __init__(self, data=None, level=1):
        super(SWFLineStyle2, self).__init__(data, level)

    def parse(self, data, level=1):
        """Read the extended record; field order follows the bit stream."""
        self.width = data.readUI16()
        self.start_caps_style = data.readUB(2)
        self.joint_style = data.readUB(2)
        self.has_fill_flag = (data.readUB(1) == 1)
        self.no_hscale_flag = (data.readUB(1) == 1)
        self.no_vscale_flag = (data.readUB(1) == 1)
        self.pixelhinting_flag = (data.readUB(1) == 1)
        data.readUB(5)  # reserved bits
        self.no_close = (data.readUB(1) == 1)
        self.end_caps_style = data.readUB(2)
        # the miter limit is only stored for MITER joints
        if self.joint_style == LineJointStyle.MITER:
            self.miter_limit_factor = data.readFIXED8()
        if self.has_fill_flag:
            self.fill_type = data.readFILLSTYLE(level)
        else:
            self.color = data.readRGBA()

    def __str__(self):
        s = "[SWFLineStyle2] "
        s += "Width: %d, " % self.width
        s += "StartCapsStyle: %d, " % self.start_caps_style
        s += "JointStyle: %d, " % self.joint_style
        s += "HasFillFlag: %d, " % self.has_fill_flag
        s += "NoHscaleFlag: %d, " % self.no_hscale_flag
        s += "NoVscaleFlag: %d, " % self.no_vscale_flag
        s += "PixelhintingFlag: %d, " % self.pixelhinting_flag
        s += "NoClose: %d, " % self.no_close

        # Bug fix: the guard was `if self.joint_style:` (any truthy value),
        # which disagrees with parse() -- the miter limit is only read for
        # MITER joints.  Mirror parse()'s condition.
        if self.joint_style == LineJointStyle.MITER:
            s += "MiterLimitFactor: %d" % self.miter_limit_factor
        if self.has_fill_flag:
            s += "FillType: %s, " % self.fill_type
        else:
            s += "Color: %s" % ColorUtils.to_rgb_string(self.color)

        return s
+
class SWFMorphGradientRecord(_dumb_repr):
    """One morph gradient stop: start and end (ratio, color) pairs."""

    def __init__(self, data):
        if data is not None:
            self.parse(data)

    def parse(self, data):
        # stream order: start ratio/color, then end ratio/color
        self.startRatio = data.readUI8()
        self.startColor = data.readRGBA()
        self.endRatio = data.readUI8()
        self.endColor = data.readRGBA()
+
class SWFMorphGradient(_dumb_repr):
    """A morph gradient: a count byte followed by that many morph records."""

    def __init__(self, data, level=1):
        self.records = []
        if data is not None:
            self.parse(data, level)

    def parse(self, data, level=1):
        count = data.readUI8()
        self.records = [data.readMORPHGRADIENTRECORD() for _ in range(count)]
+
class SWFMorphFillStyle(_dumb_repr):
    """Fill style of a morph shape, holding start and end variants."""

    def __init__(self, data, level=1):
        if not data is None:
            self.parse(data, level)

    def get_dependencies(self):
        """Bitmap fills depend on their bitmap character id."""
        return set([self.bitmapId]) if hasattr(self, 'bitmapId') else set()

    def parse(self, data, level=1):
        # Store the fill type on the instance (the old local named `type`
        # shadowed the builtin and was discarded), matching SWFFillStyle.type.
        self.type = data.readUI8()
        if self.type == 0x0:
            # solid fill
            self.startColor = data.readRGBA()
            self.endColor = data.readRGBA()
        elif self.type in [0x10, 0x12]:
            # gradient fill
            self.startGradientMatrix = data.readMATRIX()
            self.endGradientMatrix = data.readMATRIX()
            self.gradient = data.readMORPHGRADIENT(level)
        elif self.type in [0x40, 0x41, 0x42, 0x43]:
            # bitmap fill
            self.bitmapId = data.readUI16()
            self.startBitmapMatrix = data.readMATRIX()
            self.endBitmapMatrix = data.readMATRIX()
        else:
            # Previously unknown types were silently skipped, leaving the
            # stream misaligned; fail loudly like SWFFillStyle does.
            raise Exception("Unknown fill style type: 0x%x" % self.type, level)
+
class SWFMorphLineStyle(_dumb_repr):
    """Basic morph line style: start/end widths and colors."""

    def __init__(self, data, level=1):
        # Defaults double as forward declarations for SWFMorphLineStyle2, so
        # consumers may read the extended attributes on either class.
        self.startCapsStyle = LineCapsStyle.ROUND
        self.endCapsStyle = LineCapsStyle.ROUND
        self.jointStyle = LineJointStyle.ROUND
        self.hasFillFlag = False
        self.noHScaleFlag = False
        self.noVScaleFlag = False
        self.pixelHintingFlag = False
        self.noClose = False
        self.miterLimitFactor = 3
        self.fillType = None
        if data is not None:
            self.parse(data, level)

    def parse(self, data, level=1):
        # stream order: start width, end width, start color, end color
        self.startWidth = data.readUI16()
        self.endWidth = data.readUI16()
        self.startColor = data.readRGBA()
        self.endColor = data.readRGBA()
+
class SWFMorphLineStyle2(SWFMorphLineStyle):
    """Extended morph line style: caps, joints, scaling flags and fills."""

    def __init__(self, data, level=1):
        super(SWFMorphLineStyle2, self).__init__(data, level)

    def parse(self, data, level=1):
        """Read the extended record; field order follows the bit stream."""
        self.startWidth = data.readUI16()
        self.endWidth = data.readUI16()
        self.startCapsStyle = data.readUB(2)
        self.jointStyle = data.readUB(2)
        self.hasFillFlag = (data.readUB(1) == 1)
        self.noHScaleFlag = (data.readUB(1) == 1)
        self.noVScaleFlag = (data.readUB(1) == 1)
        self.pixelHintingFlag = (data.readUB(1) == 1)
        reserved = data.readUB(5);
        self.noClose = (data.readUB(1) == 1)
        self.endCapsStyle = data.readUB(2)
        # the miter limit is only stored for MITER joints
        if self.jointStyle == LineJointStyle.MITER:
            self.miterLimitFactor = data.readFIXED8()
        if self.hasFillFlag:
            self.fillType = data.readMORPHFILLSTYLE(level)
        else:
            self.startColor = data.readRGBA()
            self.endColor = data.readRGBA()
+
class SWFRecordHeader(_dumb_repr):
    """Header of a SWF tag: tag type plus content and header byte lengths."""

    def __init__(self, type, content_length, header_length):
        self.type = type
        self.content_length = content_length
        self.header_length = header_length

    @property
    def tag_length(self):
        """Total size of the tag in bytes (header + content)."""
        return self.content_length + self.header_length
+
class SWFRectangle(_dumb_repr):
    """An axis-aligned RECT (xmin/xmax/ymin/ymax) in twips."""

    def __init__(self):
        self.xmin = self.xmax = self.ymin = self.ymax = 0

    def parse(self, s):
        """Read the 5-bit field width, then the four bounds at that width."""
        s.reset_bits_pending()
        nbits = s.readUB(5)
        self.xmin = s.readSB(nbits)
        self.xmax = s.readSB(nbits)
        self.ymin = s.readSB(nbits)
        self.ymax = s.readSB(nbits)

    @property
    def dimensions(self):
        """
        Returns dimensions as (x, y) tuple.
        """
        return (self.xmax - self.xmin, self.ymax - self.ymin)

    def __str__(self):
        # printed in pixels (the stored twips divided by 20)
        return "[xmin: %d xmax: %d ymin: %d ymax: %d]" % (
            self.xmin / 20, self.xmax / 20, self.ymin / 20, self.ymax / 20)
+
class SWFColorTransform(_dumb_repr):
    """An RGB color transform: per-channel multiply and add terms."""

    def __init__(self, data=None):
        if not data is None:
            self.parse(data)

    def parse(self, data):
        """Read CXFORM: two flag bits, a 4-bit field width, then the terms."""
        data.reset_bits_pending()
        self.hasAddTerms = (data.readUB(1) == 1)
        self.hasMultTerms = (data.readUB(1) == 1)
        bits = data.readUB(4)
        # multipliers default to 1 when absent
        # NOTE(review): matrix() divides by 256, so this default yields 1/256
        # rather than an identity multiplier -- confirm intended
        self.rMult = 1
        self.gMult = 1
        self.bMult = 1
        if self.hasMultTerms:
            self.rMult = data.readSB(bits)
            self.gMult = data.readSB(bits)
            self.bMult = data.readSB(bits)
        # add terms default to 0 when absent
        self.rAdd = 0
        self.gAdd = 0
        self.bAdd = 0
        if self.hasAddTerms:
            self.rAdd = data.readSB(bits)
            self.gAdd = data.readSB(bits)
            self.bAdd = data.readSB(bits)

    @property
    def matrix(self):
        """The transform as a flat 4x5 (20-item) color matrix, terms scaled by 1/256."""
        return [
            self.rMult / 256.0, 0.0, 0.0, 0.0, self.rAdd / 256.0,
            0.0, self.gMult / 256.0, 0.0, 0.0, self.gAdd / 256.0,
            0.0, 0.0, self.bMult / 256.0, 0.0, self.bAdd / 256.0,
            0.0, 0.0, 0.0, 1.0, 1.0
        ]

    def __str__(self):
        return "[%d %d %d %d %d %d]" % \
            (self.rMult, self.gMult, self.bMult, self.rAdd, self.gAdd, self.bAdd)
+
class SWFColorTransformWithAlpha(SWFColorTransform):
    """A CXFORMWITHALPHA color transform: RGB terms plus alpha terms."""

    def __init__(self, data=None):
        super(SWFColorTransformWithAlpha, self).__init__(data)

    def parse(self, data):
        """Read CXFORMWITHALPHA: flags, field width, then RGBA mult/add terms."""
        data.reset_bits_pending()
        self.hasAddTerms = (data.readUB(1) == 1)
        self.hasMultTerms = (data.readUB(1) == 1)
        bits = data.readUB(4)
        # multipliers default to 1 when absent
        # NOTE(review): matrix() divides by 256, so this default yields 1/256
        # rather than an identity multiplier -- confirm intended
        self.rMult = 1
        self.gMult = 1
        self.bMult = 1
        self.aMult = 1
        if self.hasMultTerms:
            self.rMult = data.readSB(bits)
            self.gMult = data.readSB(bits)
            self.bMult = data.readSB(bits)
            self.aMult = data.readSB(bits)
        # add terms default to 0 when absent
        self.rAdd = 0
        self.gAdd = 0
        self.bAdd = 0
        self.aAdd = 0
        if self.hasAddTerms:
            self.rAdd = data.readSB(bits)
            self.gAdd = data.readSB(bits)
            self.bAdd = data.readSB(bits)
            self.aAdd = data.readSB(bits)

    @property
    def matrix(self):
        '''
        Gets the matrix as a 20 item list
        '''
        return [
            self.rMult / 256.0, 0.0, 0.0, 0.0, self.rAdd / 256.0,
            0.0, self.gMult / 256.0, 0.0, 0.0, self.gAdd / 256.0,
            0.0, 0.0, self.bMult / 256.0, 0.0, self.bAdd / 256.0,
            0.0, 0.0, 0.0, self.aMult / 256.0, self.aAdd / 256.0
        ]

    def __str__(self):
        return "[%d %d %d %d %d %d %d %d]" % \
            (self.rMult, self.gMult, self.bMult, self.aMult, self.rAdd, self.gAdd, self.bAdd, self.aAdd)
+
class SWFFrameLabel(_dumb_repr):
    """Associates a label name with a frame number."""

    def __init__(self, frameNumber, name):
        self.frameNumber = frameNumber
        self.name = name

    def __str__(self):
        return "Frame: %d, Name: %s" % (self.frameNumber, self.name)
+
class SWFScene(_dumb_repr):
    """A named scene starting at a given frame offset."""

    def __init__(self, offset, name):
        self.offset = offset
        self.name = name

    def __str__(self):
        return "Scene: %d, Name: '%s'" % (self.offset, self.name)
+
class SWFSymbol(_dumb_repr):
    """An exported symbol: a tag id and its linkage name."""

    def __init__(self, data=None):
        if data is not None:
            self.parse(data)

    def parse(self, data):
        self.tagId = data.readUI16()
        self.name = data.readString()

    def __str__(self):
        return "ID %d, Name: %s" % (self.tagId, self.name)
+
class SWFGlyphEntry(_dumb_repr):
    """A glyph reference within a text record: glyph index plus advance."""

    def __init__(self, data=None, glyphBits=0, advanceBits=0):
        if data is not None:
            self.parse(data, glyphBits, advanceBits)

    def parse(self, data, glyphBits, advanceBits):
        # GLYPHENTRYs are not byte aligned, so read raw bit fields
        self.index = data.readUB(glyphBits)
        self.advance = data.readSB(advanceBits)

    def __str__(self):
        return "Index: %d, Advance: %d" % (self.index, self.advance)
+
class SWFKerningRecord(_dumb_repr):
    """Kerning adjustment between a pair of character codes."""

    def __init__(self, data=None, wideCodes=False):
        if data is not None:
            self.parse(data, wideCodes)

    def parse(self, data, wideCodes):
        # wideCodes selects 16-bit character codes instead of 8-bit ones
        self.code1 = data.readUI16() if wideCodes else data.readUI8()
        self.code2 = data.readUI16() if wideCodes else data.readUI8()
        self.adjustment = data.readSI16()

    def __str__(self):
        return "Code1: %d, Code2: %d, Adjustment: %d" % (self.code1, self.code2, self.adjustment)
+
class SWFTextRecord(_dumb_repr):
    """One TEXTRECORD: optional style updates followed by glyph entries.

    Fields absent from the stream are carried over from `previousRecord`.
    """

    def __init__(self, data=None, glyphBits=0, advanceBits=0, previousRecord=None, level=1):
        self.hasFont = False
        self.hasColor = False
        self.hasYOffset = False
        self.hasXOffset = False
        self.fontId = -1
        self.textColor = 0
        self.xOffset = 0
        self.yOffset = 0
        self.textHeight = 12
        self.glyphEntries = []
        if not data is None:
            self.parse(data, glyphBits, advanceBits, previousRecord, level)

    def get_dependencies(self):
        """The record depends on its font character, if it sets one."""
        return set([self.fontId]) if self.hasFont else set()

    def parse(self, data, glyphBits, advanceBits, previousRecord=None, level=1):
        """Read the style flag byte, the optional style fields, then the glyphs."""
        self.glyphEntries = []
        styles = data.readUI8()
        self.type = styles >> 7
        self.hasFont = ((styles & 0x08) != 0)
        self.hasColor = ((styles & 0x04) != 0)
        self.hasYOffset = ((styles & 0x02) != 0)
        self.hasXOffset = ((styles & 0x01) != 0)

        if self.hasFont:
            self.fontId = data.readUI16()
        elif not previousRecord is None:
            # carry the previous record's font forward
            self.fontId = previousRecord.fontId

        if self.hasColor:
            # level 1 stores RGB, higher levels store RGBA
            self.textColor = data.readRGB() if level < 2 else data.readRGBA()
        elif not previousRecord is None:
            self.textColor = previousRecord.textColor

        if self.hasXOffset:
            self.xOffset = data.readSI16();
        elif not previousRecord is None:
            self.xOffset = previousRecord.xOffset

        if self.hasYOffset:
            self.yOffset = data.readSI16();
        elif not previousRecord is None:
            self.yOffset = previousRecord.yOffset

        if self.hasFont:
            # the text height accompanies the font id
            self.textHeight = data.readUI16()
        elif not previousRecord is None:
            self.textHeight = previousRecord.textHeight

        glyphCount = data.readUI8()
        for i in range(0, glyphCount):
            self.glyphEntries.append(data.readGLYPHENTRY(glyphBits, advanceBits))

    def __str__(self):
        return "[SWFTextRecord]"
+
class SWFClipActions(_dumb_repr):
    """Collection of clip action records attached via PlaceObject2/3."""

    def __init__(self, data=None, version=0):
        self.eventFlags = None
        self.records = []
        if data is not None:
            self.parse(data, version)

    def parse(self, data, version):
        data.readUI16()  # reserved, always 0
        self.eventFlags = data.readCLIPEVENTFLAGS(version)
        # Records run until the stream reports the end sentinel as None.
        self.records = []
        while True:
            record = data.readCLIPACTIONRECORD(version)
            if record is None:
                break
            self.records.append(record)

    def __str__(self):
        return "[SWFClipActions]"
+
class SWFClipActionRecord(_dumb_repr):
    """One clip event handler: event flags, optional key code, and actions."""

    def __init__(self, data=None, version=0):
        self.eventFlags = None
        self.keyCode = 0
        self.actions = []
        if data is not None:
            self.parse(data, version)

    def parse(self, data, version):
        self.actions = []
        self.eventFlags = data.readCLIPEVENTFLAGS(version)
        data.readUI32()  # actionRecordSize; we parse until the end marker instead
        if self.eventFlags.keyPressEvent:
            self.keyCode = data.readUI8()
        # Action records run until the stream reports the terminator as None.
        while True:
            action = data.readACTIONRECORD()
            if action is None:
                break
            self.actions.append(action)

    def __str__(self):
        return "[SWFClipActionRecord]"
+
class SWFClipEventFlags(_dumb_repr):
    """CLIPEVENTFLAGS: which events a clip action handler responds to."""

    # Class-level defaults; parse() overrides them per instance.
    keyUpEvent = False
    keyDownEvent = False
    mouseUpEvent = False
    mouseDownEvent = False
    mouseMoveEvent = False
    unloadEvent = False
    enterFrameEvent = False
    loadEvent = False
    dragOverEvent = False # SWF6
    rollOutEvent = False # SWF6
    rollOverEvent = False # SWF6
    releaseOutsideEvent = False # SWF6
    releaseEvent = False # SWF6
    pressEvent = False # SWF6
    initializeEvent = False # SWF6
    dataEvent = False
    constructEvent = False # SWF7
    keyPressEvent = False # SWF6
    dragOutEvent = False # SWF6

    def __init__(self, data=None, version=0):
        if data is not None:
            self.parse(data, version)

    def parse(self, data, version):
        # First flag byte; attribute names listed LSB (bit 0) first.
        flags1 = data.readUI8()
        for bit, attr in enumerate(("loadEvent", "enterFrameEvent", "unloadEvent",
                                    "mouseMoveEvent", "mouseDownEvent",
                                    "mouseUpEvent", "keyDownEvent", "keyUpEvent")):
            setattr(self, attr, (flags1 & (1 << bit)) != 0)
        # Second flag byte, same LSB-first ordering.
        flags2 = data.readUI8()
        for bit, attr in enumerate(("dataEvent", "initializeEvent", "pressEvent",
                                    "releaseEvent", "releaseOutsideEvent",
                                    "rollOverEvent", "rollOutEvent", "dragOverEvent")):
            setattr(self, attr, (flags2 & (1 << bit)) != 0)
        # SWF >= 6 appends two more bytes; the last one is reserved.
        if version >= 6:
            flags3 = data.readUI8()
            self.constructEvent = (flags3 & 0x04) != 0
            self.keyPressEvent = (flags3 & 0x02) != 0
            self.dragOutEvent = (flags3 & 0x01) != 0
            data.readUI8() # reserved, always 0

    def __str__(self):
        return "[SWFClipEventFlags]"
+
class SWFZoneData(_dumb_repr):
    """One alignment zone entry (coordinate and range) for a font glyph."""

    def __init__(self, data=None):
        if data is not None:
            self.parse(data)

    def parse(self, data):
        # Both fields are 16-bit half-precision floats.
        self.alignmentCoordinate = data.readFLOAT16()
        self.zoneRange = data.readFLOAT16()

    def __str__(self):
        return "[SWFZoneData]"
+
class SWFZoneRecord(_dumb_repr):
    """A glyph's alignment zones plus the X/Y mask flags."""

    def __init__(self, data=None):
        if data is not None:
            self.parse(data)

    def parse(self, data):
        # UI8 count, then that many ZONEDATA entries, then the mask byte.
        zone_count = data.readUI8()
        self.zoneData = [data.readZONEDATA() for _ in range(zone_count)]
        mask = data.readUI8()
        self.maskX = (mask & 0x01) != 0
        self.maskY = (mask & 0x02) != 0

    def __str__(self):
        return "[SWFZoneRecord]"
+
class SWFSoundInfo(_dumb_repr):
    """SOUNDINFO record: playback parameters for a StartSound/button sound.

    Optional fields are only present in the stream when the corresponding
    Has* flag is set; absent fields are stored as None.
    """

    def __init__(self, data=None):
        if data is not None:
            self.parse(data)

    def parse(self, data):
        """Read the SOUNDINFO record from *data* (a SWF bit stream)."""
        reserved = data.readUB(2)
        assert reserved == 0
        self.syncStop = data.readUB(1) == 1
        self.syncNoMultiple = data.readUB(1) == 1
        self.hasEnvelope = data.readUB(1) == 1
        self.hasLoops = data.readUB(1) == 1
        self.hasOutPoint = data.readUB(1) == 1
        self.hasInPoint = data.readUB(1) == 1
        self.inPoint = data.readUI32() if self.hasInPoint else None
        self.outPoint = data.readUI32() if self.hasOutPoint else None
        self.loopCount = data.readUI16() if self.hasLoops else None
        self.envPointCount = data.readUI8() if self.hasEnvelope else None
        # BUG FIX: the original used xrange(), which does not exist on
        # Python 3 (the rest of this module targets Python 3), so any
        # sound with an envelope raised NameError.
        self.envelopePoints = [data.readSOUNDENVELOPE() for _ in range(self.envPointCount)] if self.hasEnvelope else None

    def __str__(self):
        return "[SWFSoundInfo]"
+
class SWFSoundEnvelope(_dumb_repr):
    """One SOUNDENVELOPE point: sample position plus left/right levels."""

    def __init__(self, data=None):
        if data is not None:
            self.parse(data)

    def parse(self, data):
        # UI32 position followed by two UI16 channel levels.
        self.position = data.readUI32()
        self.leftLevel = data.readUI16()
        self.rightLevel = data.readUI16()

    def __str__(self):
        return "[SWFSoundEnvelope]"
+
class SWFButtonRecord(_dumb_repr):
    """BUTTONRECORD: a character placed on some of a button's states."""

    def __init__(self, version, data=None):
        # version is 1 for DefineButton, 2 for DefineButton2, etc
        if data is not None:
            self.parse(data, version)

    def get_dependencies(self):
        return {self.characterId} if self.valid else set()

    def parse(self, data, version):
        reserved0 = data.readUB(2)
        # Six single-bit flags, in stream order.
        for attr in ("hasBlendMode", "hasFilterList", "stateHitTest",
                     "stateDown", "stateOver", "stateUp"):
            setattr(self, attr, data.readUB(1) == 1)

        # A record whose first byte is all zeroes terminates the list.
        self.valid = (reserved0 or self.hasBlendMode or self.hasFilterList
                      or self.stateHitTest or self.stateDown
                      or self.stateOver or self.stateUp)
        if not self.valid:
            return

        self.characterId = data.readUI16()
        self.placeDepth = data.readUI16()
        self.placeMatrix = data.readMATRIX()

        if version == 2:
            # DefineButton2 adds a color transform plus optional filter
            # list and blend mode.
            self.colorTransform = data.readCXFORMWITHALPHA()
            self.filterList = data.readFILTERLIST() if self.hasFilterList else None
            self.blendMode = data.readUI8() if self.hasBlendMode else 0

    def __str__(self):
        return "[SWFButtonRecord]"

    def __repr__(self):
        return "[SWFButtonRecord %r]" % self.__dict__
+
class SWFButtonCondAction(_dumb_repr):
    """BUTTONCONDACTION: state-transition conditions plus the actions to run."""

    def __init__(self, data=None):
        if data is not None:
            self.parse(data)

    def parse(self, data):
        # Eight single-bit transition flags, in stream order.
        for attr in ("idleToOverDown", "outDownToIdle", "outDownToOverDown",
                     "overDownToOutDown", "overDownToOverUp", "overUpToOverDown",
                     "overUpToIdle", "idleToOverUp"):
            setattr(self, attr, data.readUB(1) == 1)

        # 7-bit key code, then the last transition flag.
        self.keyPress = data.readUB(7)
        self.overDownToIdle = data.readUB(1) == 1

        self.actions = data.readACTIONRECORDs()

    def __str__(self):
        return "[SWFButtonCondAction]"
+
class SWFExport(_dumb_repr):
    """An ExportAssets entry mapping a character id to an export name."""

    def __init__(self, data=None):
        if data is not None:
            self.parse(data)

    def get_dependencies(self):
        # The exported character itself is the only dependency.
        return {self.characterId}

    def parse(self, data):
        # UI16 character id followed by a null-terminated name string.
        self.characterId = data.readUI16()
        self.characterName = data.readString()

    def __str__(self):
        return "[SWFExport %d as %r]" % (self.characterId, self.characterName)

+ 1065 - 0
format_convert/swf/export.py

@@ -0,0 +1,1065 @@
+"""
+This module defines exporters for the SWF fileformat.
+"""
+from .consts import *
+from .geom import *
+from .utils import *
+from .data import *
+from .tag import *
+from .filters import *
+from lxml import objectify
+from lxml import etree
+import base64
+try:
+    import Image
+except ImportError:
+    from PIL import Image
+try:
+    from cBytesIO import BytesIO
+except ImportError:
+    from io import BytesIO
+import math
+import re
+import copy
+import cgi
+
SVG_VERSION = "1.1"
SVG_NS      = "http://www.w3.org/2000/svg"
XLINK_NS    = "http://www.w3.org/1999/xlink"
XLINK_HREF  = "{%s}href" % XLINK_NS
NS = {"svg" : SVG_NS, "xlink" : XLINK_NS}

# SWF coordinates are stored in twips: 20 twips per pixel.
PIXELS_PER_TWIP = 20
# SWF font glyphs are defined on a 1024 x 1024 EM square.
EM_SQUARE_LENGTH = 1024

# Strokes thinner than this are widened so they stay visible when rendered.
MINIMUM_STROKE_WIDTH = 0.5

# SWF cap-style codes mapped to SVG stroke-linecap values.
CAPS_STYLE = {
    0 : 'round',
    1 : 'butt',
    2 : 'square'
}

# SWF join-style codes mapped to SVG stroke-linejoin values.
JOIN_STYLE = {
    0 : 'round',
    1 : 'bevel',
    2 : 'miter'
}
+
class DefaultShapeExporter(object):
    """
    The default (abstract) Shape exporter class.
    All shape exporters should extend this class.

    Every callback below is a no-op; concrete exporters override only the
    hooks they need (fill/line style changes and path drawing commands).
    """
    def __init__(self, swf=None, debug=False, force_stroke=False):
        # NOTE(review): the swf argument is accepted but never stored
        # (self.swf is unconditionally None) -- confirm this is intended.
        self.swf = None
        self.debug = debug
        self.force_stroke = force_stroke

    def begin_bitmap_fill(self, bitmap_id, matrix=None, repeat=False, smooth=False):
        pass
    def begin_fill(self, color, alpha=1.0):
        pass
    def begin_gradient_fill(self, type, colors, alphas, ratios,
                            matrix=None,
                            spreadMethod=SpreadMethod.PAD,
                            interpolationMethod=InterpolationMethod.RGB,
                            focalPointRatio=0.0):
        pass
    def line_style(self,
                    thickness=float('nan'), color=0, alpha=1.0,
                    pixelHinting=False,
                    scaleMode=LineScaleMode.NORMAL,
                    startCaps=None, endCaps=None,
                    joints=None, miterLimit=3.0):
        pass
    def line_gradient_style(self,
                    thickness=float('nan'), color=0, alpha=1.0,
                    pixelHinting=False,
                    scaleMode=LineScaleMode.NORMAL,
                    startCaps=None, endCaps=None,
                    joints=None, miterLimit=3.0,
                    type = 1, colors = [], alphas = [], ratios = [],
                    matrix=None,
                    spreadMethod=SpreadMethod.PAD,
                    interpolationMethod=InterpolationMethod.RGB,
                    focalPointRatio=0.0):
        pass
    def line_bitmap_style(self,
                    thickness=float('nan'),
                    pixelHinting=False,
                    scaleMode=LineScaleMode.NORMAL,
                    startCaps=None, endCaps=None,
                    joints=None, miterLimit = 3.0,
                    bitmap_id=None, matrix=None, repeat=False, smooth=False):
        pass
    def end_fill(self):
        pass

    def begin_fills(self):
        pass
    def end_fills(self):
        pass
    def begin_lines(self):
        pass
    def end_lines(self):
        pass

    def begin_shape(self):
        pass
    def end_shape(self):
        pass

    def move_to(self, x, y):
        #print "move_to", x, y
        pass
    def line_to(self, x, y):
        #print "line_to", x, y
        pass
    def curve_to(self, cx, cy, ax, ay):
        #print "curve_to", cx, cy, ax, ay
        pass
+
class DefaultSVGShapeExporter(DefaultShapeExporter):
    """Shape exporter that accumulates an SVG path-data ("d") string.

    finalize_path() is the flush hook; this base class merely resets the
    accumulators, subclasses turn the data into actual <path> elements.
    """

    def __init__(self, defs=None):
        self.defs = defs
        self.current_draw_command = ""
        self.path_data = ""
        self._e = objectify.ElementMaker(annotate=False,
                        namespace=SVG_NS, nsmap={None : SVG_NS, "xlink" : XLINK_NS})
        super(DefaultSVGShapeExporter, self).__init__()

    def move_to(self, x, y):
        # Every moveto re-emits its command letter, so clear the cache.
        self.current_draw_command = ""
        self.path_data += "M%s %s " % (
            NumberUtils.round_pixels_20(x),
            NumberUtils.round_pixels_20(y))

    def line_to(self, x, y):
        # Consecutive linetos share a single "L" command letter.
        if self.current_draw_command != "L":
            self.current_draw_command = "L"
            self.path_data += "L"
        self.path_data += "%s %s " % (
            NumberUtils.round_pixels_20(x),
            NumberUtils.round_pixels_20(y))

    def curve_to(self, cx, cy, ax, ay):
        # Quadratic Bezier: control point followed by anchor point.
        if self.current_draw_command != "Q":
            self.current_draw_command = "Q"
            self.path_data += "Q"
        self.path_data += "%s %s %s %s " % (
            NumberUtils.round_pixels_20(cx),
            NumberUtils.round_pixels_20(cy),
            NumberUtils.round_pixels_20(ax),
            NumberUtils.round_pixels_20(ay))

    def begin_bitmap_fill(self, bitmap_id, matrix=None, repeat=False, smooth=False):
        self.finalize_path()

    def begin_fill(self, color, alpha=1.0):
        self.finalize_path()

    def end_fill(self):
        # Intentionally a no-op: the path is flushed when the next style starts.
        pass

    def begin_fills(self):
        pass
    def end_fills(self):
        self.finalize_path()

    def begin_gradient_fill(self, type, colors, alphas, ratios,
                            matrix=None,
                            spreadMethod=SpreadMethod.PAD,
                            interpolationMethod=InterpolationMethod.RGB,
                            focalPointRatio=0.0):
        self.finalize_path()

    def line_style(self,
                    thickness=float('nan'), color=0, alpha=1.0,
                    pixelHinting=False,
                    scaleMode=LineScaleMode.NORMAL,
                    startCaps=None, endCaps=None,
                    joints=None, miterLimit=3.0):
        self.finalize_path()

    def end_lines(self):
        self.finalize_path()

    def end_shape(self):
        self.finalize_path()

    def finalize_path(self):
        # Base implementation only resets the accumulators.
        self.current_draw_command = ""
        self.path_data = ""
+
class SVGShapeExporter(DefaultSVGShapeExporter):
    """Concrete SVG shape exporter.

    Turns fill/line style callbacks into <path> elements (collected in
    self.g) plus shared <linearGradient>/<radialGradient>/<pattern>
    definitions appended to the document's <defs> element.
    """
    def __init__(self):
        # Current <path> element being populated; replaced by finalize_path().
        self.path = None
        self.num_patterns = 0
        self.num_gradients = 0
        # Gradient deduplication maps: serialized gradient XML -> element / id.
        self._gradients = {}
        self._gradient_ids = {}
        self.paths = {}
        self.fills_ended = False
        super(SVGShapeExporter, self).__init__()

    def begin_shape(self):
        # Group that collects this shape's finished paths.
        self.g = self._e.g()

    def begin_fill(self, color, alpha=1.0):
        """Start a solid fill; color is an RGB int, alpha in [0, 1]."""
        self.finalize_path()
        self.path.set("fill", ColorUtils.to_rgb_string(color))
        if alpha < 1.0:
            self.path.set("fill-opacity", str(alpha))
        elif self.force_stroke:
            # Optionally stroke opaque fills with their own color.
            self.path.set("stroke", ColorUtils.to_rgb_string(color))
            self.path.set("stroke-width", "1")
        else:
            self.path.set("stroke", "none")

    def begin_gradient_fill(self, type, colors, alphas, ratios,
                            matrix=None,
                            spreadMethod=SpreadMethod.PAD,
                            interpolationMethod=InterpolationMethod.RGB,
                            focalPointRatio=0.0):
        self.finalize_path()
        gradient_id = self.export_gradient(type, colors, alphas, ratios, matrix, spreadMethod, interpolationMethod, focalPointRatio)
        self.path.set("stroke", "none")
        self.path.set("fill", "url(#%s)" % gradient_id)

    def export_gradient(self, type, colors, alphas, ratios,
                        matrix=None,
                        spreadMethod=SpreadMethod.PAD,
                        interpolationMethod=InterpolationMethod.RGB,
                        focalPointRatio=0.0):
        """Emit a gradient definition into <defs> (deduplicated) and return its id."""
        self.num_gradients += 1
        gradient_id = "gradient%d" % self.num_gradients
        gradient = self._e.linearGradient() if type == GradientType.LINEAR \
            else self._e.radialGradient()
        gradient.set("gradientUnits", "userSpaceOnUse")

        # 819.2 px corresponds to the SWF gradient square extent of
        # 16384 twips (16384 / 20 twips-per-pixel).
        if type == GradientType.LINEAR:
            gradient.set("x1", "-819.2")
            gradient.set("x2", "819.2")
        else:
            gradient.set("r", "819.2")
            gradient.set("cx", "0")
            gradient.set("cy", "0")
            if focalPointRatio < 0.0 or focalPointRatio > 0.0:
                gradient.set("fx", str(819.2 * focalPointRatio))
                gradient.set("fy", "0")

        if spreadMethod == SpreadMethod.PAD:
            gradient.set("spreadMethod", "pad")
        elif spreadMethod == SpreadMethod.REFLECT:
            gradient.set("spreadMethod", "reflect")
        elif spreadMethod == SpreadMethod.REPEAT:
            gradient.set("spreadMethod", "repeat")

        if interpolationMethod == InterpolationMethod.LINEAR_RGB:
            gradient.set("color-interpolation", "linearRGB")

        if matrix is not None:
            sm = _swf_matrix_to_svg_matrix(matrix)
            gradient.set("gradientTransform", sm);

        # SWF stop ratios are 0..255; SVG offsets are 0..1.
        for i in range(0, len(colors)):
            entry = self._e.stop()
            offset = ratios[i] / 255.0
            entry.set("offset", str(offset))
            # SVG defaults (stop-color: black, stop-opacity: 1) are omitted.
            if colors[i] != 0.0:
                entry.set("stop-color", ColorUtils.to_rgb_string(colors[i]))
            if alphas[i] != 1.0:
                entry.set("stop-opacity", str(alphas[i]))
            gradient.append(entry)

        # prevent same gradient in <defs />
        # (the id attribute is set after serialization so identical
        # gradients compare equal regardless of their sequence number)
        key = etree.tostring(gradient)
        if key in self._gradients:
            gradient_id = self._gradient_ids[key]
        else:
            self._gradients[key] = copy.copy(gradient)
            self._gradient_ids[key] = gradient_id
            gradient.set("id", gradient_id)
            self.defs.append(gradient)

        return gradient_id

    def export_pattern(self, bitmap_id, matrix, repeat=False, smooth=False):
        """Emit a <pattern> referencing an already-exported bitmap <image>."""
        self.num_patterns += 1
        bitmap_id = "c%d" % bitmap_id
        # The bitmap must have been exported into <defs> beforehand.
        e = self.defs.xpath("./svg:image[@id='%s']" % bitmap_id, namespaces=NS)
        if len(e) < 1:
            raise Exception("SVGShapeExporter::begin_bitmap_fill Could not find bitmap!")
        image = e[0]
        pattern_id = "pat%d" % (self.num_patterns)
        pattern = self._e.pattern()
        pattern.set("id", pattern_id)
        pattern.set("width", image.get("width"))
        pattern.set("height", image.get("height"))
        pattern.set("patternUnits", "userSpaceOnUse")
        #pattern.set("patternContentUnits", "objectBoundingBox")
        if matrix is not None:
            pattern.set("patternTransform", _swf_matrix_to_svg_matrix(matrix, True, True, True))
            pass
        use = self._e.use()
        use.set(XLINK_HREF, "#%s" % bitmap_id)
        pattern.append(use)
        self.defs.append(pattern)

        return pattern_id

    def begin_bitmap_fill(self, bitmap_id, matrix=None, repeat=False, smooth=False):
        self.finalize_path()
        pattern_id = self.export_pattern(bitmap_id, matrix, repeat, smooth)
        self.path.set("stroke", "none")
        self.path.set("fill", "url(#%s)" % pattern_id)

    def line_style(self,
                    thickness=float('nan'), color=0, alpha=1.0,
                    pixelHinting=False,
                    scaleMode=LineScaleMode.NORMAL,
                    startCaps=None, endCaps=None,
                    joints=None, miterLimit=3.0):
        self.finalize_path()
        self.path.set("fill", "none")
        self.path.set("stroke", ColorUtils.to_rgb_string(color))
        # NaN thickness means "default"; also clamp to the minimum width.
        thickness = 1 if math.isnan(thickness) else thickness
        thickness = MINIMUM_STROKE_WIDTH if thickness < MINIMUM_STROKE_WIDTH else thickness
        self.path.set("stroke-width", str(thickness))
        if alpha < 1.0:
            self.path.set("stroke-opacity", str(alpha))

    # NOTE(review): unlike the base class, this override drops the
    # color/alpha parameters from the signature -- confirm callers
    # pass the remaining arguments by keyword.
    def line_gradient_style(self,
                    thickness=float('nan'),
                    pixelHinting = False,
                    scaleMode=LineScaleMode.NORMAL,
                    startCaps=0, endCaps=0,
                    joints=0, miterLimit=3.0,
                    type = 1,
                    colors = [],
                    alphas = [],
                    ratios = [],
                    matrix=None,
                    spreadMethod=SpreadMethod.PAD,
                    interpolationMethod=InterpolationMethod.RGB,
                    focalPointRatio=0.0):
        self.finalize_path()
        gradient_id = self.export_gradient(type, colors, alphas, ratios, matrix, spreadMethod, interpolationMethod, focalPointRatio)
        self.path.set("fill", "none")
        self.path.set("stroke-linejoin", JOIN_STYLE[joints])
        self.path.set("stroke-linecap", CAPS_STYLE[startCaps])
        self.path.set("stroke", "url(#%s)" % gradient_id)
        thickness = 1 if math.isnan(thickness) else thickness
        thickness = MINIMUM_STROKE_WIDTH if thickness < MINIMUM_STROKE_WIDTH else thickness
        self.path.set("stroke-width", str(thickness))

    def line_bitmap_style(self,
                    thickness=float('nan'),
                    pixelHinting=False,
                    scaleMode=LineScaleMode.NORMAL,
                    startCaps=None, endCaps=None,
                    joints=None, miterLimit = 3.0,
                    bitmap_id=None, matrix=None, repeat=False, smooth=False):
        self.finalize_path()
        pattern_id = self.export_pattern(bitmap_id, matrix, repeat, smooth)
        self.path.set("fill", "none")
        self.path.set("stroke", "url(#%s)" % pattern_id)
        self.path.set("stroke-linejoin", JOIN_STYLE[joints])
        self.path.set("stroke-linecap", CAPS_STYLE[startCaps])
        thickness = 1 if math.isnan(thickness) else thickness
        thickness = MINIMUM_STROKE_WIDTH if thickness < MINIMUM_STROKE_WIDTH else thickness
        self.path.set("stroke-width", str(thickness))

    def begin_fills(self):
        self.fills_ended = False
    def end_fills(self):
        self.finalize_path()
        self.fills_ended = True

    def finalize_path(self):
        # Flush the accumulated path data into a <path> inside the group,
        # then start a fresh, empty <path> for the next style run.
        if self.path is not None and len(self.path_data) > 0:
            self.path_data = self.path_data.rstrip()
            self.path.set("d", self.path_data)
            self.g.append(self.path)
        self.path = self._e.path()
        super(SVGShapeExporter, self).finalize_path()
+
+
class BaseExporter(object):
    """Walks a SWF's tags, exporting definitions and the display list
    through a shape exporter (SVG by default)."""
    def __init__(self, swf=None, shape_exporter=None, force_stroke=False):
        self.shape_exporter = SVGShapeExporter() if shape_exporter is None else shape_exporter
        self.clip_depth = 0
        self.mask_id = None
        # Shared JPEG encoding tables (TagJPEGTables), used by DefineBits.
        self.jpegTables = None
        self.force_stroke = force_stroke
        if swf is not None:
            self.export(swf)

    def export(self, swf, force_stroke=False):
        """Export all definition tags, then the (depth-sorted) display list."""
        self.force_stroke = force_stroke
        self.export_define_shapes(swf.tags)
        self.export_display_list(self.get_display_tags(swf.tags))

    def export_define_bits(self, tag):
        """Decode a DefineBits / JPEG2 / JPEG3 tag into a PIL image and export it."""
        # NOTE(review): png_buffer is never used in this method.
        png_buffer = BytesIO()
        image = None
        if isinstance(tag, TagDefineBitsJPEG3):
            # JPEG3 carries a separate alpha plane that is merged with the
            # JPEG color data one pixel at a time.
            tag.bitmapData.seek(0)
            tag.bitmapAlphaData.seek(0, 2)
            num_alpha = tag.bitmapAlphaData.tell()
            tag.bitmapAlphaData.seek(0)
            image = Image.open(tag.bitmapData)
            if num_alpha > 0:
                image_width = image.size[0]
                image_height = image.size[1]
                image_data = image.getdata()
                image_data_len = len(image_data)
                # Only merge when there is exactly one alpha byte per pixel.
                if num_alpha == image_data_len:
                    buff = b""
                    for i in range(0, num_alpha):
                        alpha = ord(tag.bitmapAlphaData.read(1))
                        rgb = list(image_data[i])
                        buff += struct.pack("BBBB", rgb[0], rgb[1], rgb[2], alpha)
                    image = Image.frombytes("RGBA", (image_width, image_height), buff)
        elif isinstance(tag, TagDefineBitsJPEG2):
            # JPEG2 is a self-contained image stream.
            tag.bitmapData.seek(0)
            image = Image.open(tag.bitmapData)
        else:
            # Plain DefineBits needs the shared JPEG tables prepended,
            # when present, to form a decodable stream.
            tag.bitmapData.seek(0)
            if self.jpegTables is not None:
                buff = BytesIO()
                self.jpegTables.seek(0)
                buff.write(self.jpegTables.read())
                buff.write(tag.bitmapData.read())
                buff.seek(0)
                image = Image.open(buff)
            else:
                image = Image.open(tag.bitmapData)

        self.export_image(tag, image)

    def export_define_bits_lossless(self, tag):
        # bitmapData is expected to hold a stream PIL can open directly.
        tag.bitmapData.seek(0)
        image = Image.open(tag.bitmapData)
        self.export_image(tag, image)

    def export_define_sprite(self, tag, parent=None):
        # A sprite is a nested timeline: export its own display list.
        display_tags = self.get_display_tags(tag.tags)
        self.export_display_list(display_tags, parent)

    def export_define_shape(self, tag):
        # Debug output is only enabled for DefineShape4 tags.
        self.shape_exporter.debug = isinstance(tag, TagDefineShape4)
        tag.shapes.export(self.shape_exporter)

    def export_define_shapes(self, tags):
        """Recursively export every definition tag (sprites, shapes,
        bitmaps, fonts, text)."""
        for tag in tags:
            if isinstance(tag, SWFTimelineContainer):
                self.export_define_sprite(tag)
                self.export_define_shapes(tag.tags)
            elif isinstance(tag, TagDefineShape):
                self.export_define_shape(tag)
            elif isinstance(tag, TagJPEGTables):
                # Remember the shared JPEG tables for later DefineBits tags.
                if tag.length > 0:
                    self.jpegTables = tag.jpegTables
            elif isinstance(tag, TagDefineBits):
                self.export_define_bits(tag)
            elif isinstance(tag, TagDefineBitsLossless):
                self.export_define_bits_lossless(tag)
            elif isinstance(tag, TagDefineFont):
                self.export_define_font(tag)
            elif isinstance(tag, TagDefineText):
                self.export_define_text(tag)

    def export_display_list(self, tags, parent=None):
        self.clip_depth = 0
        for tag in tags:
            self.export_display_list_item(tag, parent)

    def export_display_list_item(self, tag, parent=None):
        # Hook for subclasses; the base exporter renders nothing.
        pass

    def export_image(self, tag, image=None):
        # Hook for subclasses; the base exporter renders nothing.
        pass

    def get_display_tags(self, tags, z_sorted=True):
        """Collect PlaceObject tags up to the first ShowFrame (i.e. the
        first frame only), optionally sorted by depth (z-order)."""
        dp_tuples = []
        for tag in tags:
            if isinstance(tag, TagPlaceObject):
                dp_tuples.append((tag, tag.depth))
            elif isinstance(tag, TagShowFrame):
                break
        if z_sorted:
            dp_tuples = sorted(dp_tuples, key=lambda tag_info: tag_info[1])
        display_tags = []
        for item in dp_tuples:
            display_tags.append(item[0])
        return display_tags

    def serialize(self):
        # Hook for subclasses; the base exporter has no output format.
        return None
+
+class SVGExporter(BaseExporter):
+    def __init__(self, swf=None, margin=0):
+        self._e = objectify.ElementMaker(annotate=False,
+                        namespace=SVG_NS, nsmap={None : SVG_NS, "xlink" : XLINK_NS})
+        self._margin = margin
+        super(SVGExporter, self).__init__(swf)
+
    def export(self, swf, force_stroke=False):
        """ Exports the specified SWF to SVG.

        @param swf  The SWF.
        @param force_stroke Whether to force strokes on non-stroked fills.
        """
        self.svg = self._e.svg(version=SVG_VERSION)
        self.force_stroke = force_stroke
        self.defs = self._e.defs()
        self.root = self._e.g()
        self.svg.append(self.defs)
        self.svg.append(self.root)
        self.shape_exporter.defs = self.defs
        self._num_filters = 0
        # Index font and font-info tags by character id for text export.
        self.fonts = dict([(x.characterId,x) for x in swf.all_tags_of_type(TagDefineFont)])
        self.fontInfos = dict([(x.characterId,x) for x in swf.all_tags_of_type(TagDefineFontInfo)])

        # GO!
        super(SVGExporter, self).export(swf, force_stroke)

        # Setup svg @width, @height and @viewBox
        # and add the optional margin
        self.bounds = SVGBounds(self.svg)
        self.svg.set("width", "%dpx" % round(self.bounds.width))
        self.svg.set("height", "%dpx" % round(self.bounds.height))
        if self._margin > 0:
            self.bounds.grow(self._margin)
        vb = [self.bounds.minx, self.bounds.miny,
              self.bounds.width, self.bounds.height]
        self.svg.set("viewBox", "%s" % " ".join(map(str,vb)))

        # Return the SVG as BytesIO
        return self._serialize()
+
+    def _serialize(self):
+        return BytesIO(etree.tostring(self.svg,
+                encoding="UTF-8", xml_declaration=True))
+
+    def export_define_sprite(self, tag, parent=None):
+        id = "c%d"%tag.characterId
+        g = self._e.g(id=id)
+        self.defs.append(g)
+        self.clip_depth = 0
+        super(SVGExporter, self).export_define_sprite(tag, g)
+
    def export_define_font(self, tag):
        """Export a font's glyphs as reusable <path> defs.

        Only fonts whose font-info tag requests glyph text are exported;
        device-text fonts are rendered as plain SVG <text> elsewhere.
        """
        if not tag.characterId in self.fontInfos:
            return
        fontInfo = self.fontInfos[tag.characterId]
        if not fontInfo.useGlyphText:
            return

        defs = self._e.defs(id="font_{0}".format(tag.characterId))

        for index, glyph in enumerate(tag.glyphShapeTable):
            # Export the glyph as a shape and add the path to the "defs"
            # element to be referenced later when exporting text.
            code_point = fontInfo.codeTable[index]
            pathGroup = glyph.export().g.getchildren()

            if len(pathGroup):
                path = pathGroup[0]

                # Id encodes both font and code point: font_<id>_<codepoint>.
                path.set("id", "font_{0}_{1}".format(tag.characterId, code_point))

                # SWF glyphs are always defined on an EM square of 1024 by 1024 units.
                path.set("transform", "scale({0})".format(float(1)/EM_SQUARE_LENGTH))

                # We'll be setting the color on the USE element that
                # references this element.
                del path.attrib["stroke"]
                del path.attrib["fill"]

                defs.append(path)

        self.defs.append(defs)
+
    def export_define_text(self, tag):
        """Export a DefineText tag as a <g class="text_content"> in <defs>.

        Glyph-text fonts are rendered as <use> references to the paths that
        export_define_font() produced; device-text fonts fall back to a plain
        SVG <text> element with per-glyph x positions.
        """
        g = self._e.g(id="c{0}".format(int(tag.characterId)))
        g.set("class", "text_content")

        # Pen position in pixels (SWF stores offsets/advances in twips).
        x = 0
        y = 0

        for rec in tag.records:
            if rec.hasXOffset:
                x = rec.xOffset/PIXELS_PER_TWIP
            if rec.hasYOffset:
                y = rec.yOffset/PIXELS_PER_TWIP

            size = rec.textHeight/PIXELS_PER_TWIP
            # Records referencing an unknown font are skipped entirely.
            if rec.fontId not in self.fontInfos:
                continue
            fontInfo = self.fontInfos[rec.fontId]

            if not fontInfo.useGlyphText:
                # Accumulate characters plus per-glyph x positions for one
                # <text> element emitted after the glyph loop.
                inner_text = ""
                xValues = []

            for glyph in rec.glyphEntries:
                code_point = fontInfo.codeTable[glyph.index]

                # Ignore control characters
                if code_point in range(32):
                    continue

                if fontInfo.useGlyphText:
                    use = self._e.use()
                    use.set(XLINK_HREF, "#font_{0}_{1}".format(rec.fontId, code_point))

                    # Glyph paths were scaled to a unit square when defined,
                    # so scale by font size and translate in glyph-local units.
                    use.set(
                        'transform',
                        "scale({0}) translate({1} {2})".format(
                            size, float(x)/size, float(y)/size
                        )
                    )

                    color = ColorUtils.to_rgb_string(ColorUtils.rgb(rec.textColor))
                    use.set("style", "fill: {0}; stroke: {0}".format(color))

                    g.append(use)
                else:
                    inner_text += chr(code_point)
                    xValues.append(str(x))

                # Advance the pen by the glyph's advance (twips -> pixels).
                x = x + float(glyph.advance)/PIXELS_PER_TWIP

            if not fontInfo.useGlyphText:
                text = self._e.text(inner_text)

                text.set("font-family", fontInfo.fontName)
                text.set("font-size", str(size))
                text.set("fill", ColorUtils.to_rgb_string(ColorUtils.rgb(rec.textColor)))

                text.set("y", str(y))
                text.set("x", " ".join(xValues))

                if fontInfo.bold:
                    text.set("font-weight", "bold")
                if fontInfo.italic:
                    text.set("font-style", "italic")

                g.append(text)

        self.defs.append(g)
+
+    def export_define_shape(self, tag):
+        self.shape_exporter.force_stroke = self.force_stroke
+        super(SVGExporter, self).export_define_shape(tag)
+        shape = self.shape_exporter.g
+        shape.set("id", "c%d" % tag.characterId)
+        self.defs.append(shape)
+
+    def export_display_list_item(self, tag, parent=None):
+        g = self._e.g()
+        use = self._e.use()
+        is_mask = False
+
+        if tag.hasMatrix:
+            use.set("transform", _swf_matrix_to_svg_matrix(tag.matrix))
+        if tag.hasClipDepth:
+            self.mask_id = "mask%d" % tag.characterId
+            self.clip_depth = tag.clipDepth
+            g = self._e.mask(id=self.mask_id)
+            # make sure the mask is completely filled white
+            paths = self.defs.xpath("./svg:g[@id='c%d']/svg:path" % tag.characterId, namespaces=NS)
+            for path in paths:
+                path.set("fill", "#ffffff")
+        elif tag.depth <= self.clip_depth and self.mask_id is not None:
+            g.set("mask", "url(#%s)" % self.mask_id)
+
+        filters = []
+        filter_cxform = None
+        self._num_filters += 1
+        filter_id = "filter%d" % self._num_filters
+        svg_filter = self._e.filter(id=filter_id)
+
+        if tag.hasColorTransform:
+            filter_cxform = self.export_color_transform(tag.colorTransform, svg_filter)
+            filters.append(filter_cxform)
+        if tag.hasFilterList and len(tag.filters) > 0:
+            cxform = "color-xform" if tag.hasColorTransform else None
+            f = self.export_filters(tag, svg_filter, cxform)
+            if len(f) > 0:
+                filters.extend(f)
+        if tag.hasColorTransform or (tag.hasFilterList and len(filters) > 0):
+            self.defs.append(svg_filter)
+            use.set("filter", "url(#%s)" % filter_id)
+
+        use.set(XLINK_HREF, "#c%s" % tag.characterId)
+        g.append(use)
+
+        if is_mask:
+            self.defs.append(g)
+        else:
+            if parent is not None:
+                parent.append(g)
+            else:
+                self.root.append(g)
+        return use
+
+    def export_color_transform(self, cxform, svg_filter, result='color-xform'):
+        fe_cxform = self._e.feColorMatrix()
+        fe_cxform.set("in", "SourceGraphic")
+        fe_cxform.set("type", "matrix")
+        fe_cxform.set("values", " ".join(map(str, cxform.matrix)))
+        fe_cxform.set("result", "cxform")
+
+        fe_composite = self._e.feComposite(operator="in")
+        fe_composite.set("in2", "SourceGraphic")
+        fe_composite.set("result", result)
+
+        svg_filter.append(fe_cxform)
+        svg_filter.append(fe_composite)
+        return result
+
    def export_filters(self, tag, svg_filter, cxform=None):
        """Translate the tag's SWF filter list into SVG filter primitives.

        Only drop-shadow and color-matrix filters are implemented; the other
        filter types are recognised but silently skipped.  Returns the list
        of result names produced (used by the caller to decide whether the
        filter element is attached at all).
        """
        num_filters = len(tag.filters)
        elements = []
        attr_in = None
        for i in range(0, num_filters):
            swf_filter = tag.filters[i]
            if isinstance(swf_filter, FilterDropShadow):
                # cxform is forwarded as blend_in so the shadow is blended on
                # top of a preceding color-transform result, if any.
                elements.append(self.export_filter_dropshadow(swf_filter, svg_filter, cxform))
                pass
            elif isinstance(swf_filter, FilterBlur):
                pass  # not implemented
            elif isinstance(swf_filter, FilterGlow):
                #attr_in = SVGFilterFactory.export_glow_filter(self._e, svg_filter, attr_in=attr_in)
                #elements.append(attr_in)
                pass  # not implemented
            elif isinstance(swf_filter, FilterBevel):
                pass  # not implemented
            elif isinstance(swf_filter, FilterGradientGlow):
                pass  # not implemented
            elif isinstance(swf_filter, FilterConvolution):
                pass  # not implemented
            elif isinstance(swf_filter, FilterColorMatrix):
                attr_in = SVGFilterFactory.export_color_matrix_filter(self._e, svg_filter, swf_filter.colorMatrix, svg_filter, attr_in=attr_in)
                elements.append(attr_in)
                pass
            elif isinstance(swf_filter, FilterGradientBevel):
                pass  # not implemented
            else:
                raise Exception("unknown filter: ", swf_filter)
        return elements
+
+#   <filter id="test-filter" x="-50%" y="-50%" width="200%" height="200%">
+#		<feGaussianBlur in="SourceAlpha" stdDeviation="6" result="blur"/>
+#		<feOffset dy="0" dx="0"/>
+#		<feComposite in2="SourceAlpha" operator="arithmetic"
+#			k2="-1" k3="1" result="shadowDiff"/>
+#		<feFlood flood-color="black" flood-opacity="1"/>
+#		<feComposite in2="shadowDiff" operator="in"/>
+#	</filter>;
+
    def export_filter_dropshadow(self, swf_filter, svg_filter, blend_in=None, result="offsetBlur"):
        """Append SVG primitives approximating an SWF drop-shadow filter.

        Knockout shadows get a hand-built feGaussianBlur/feComposite/feFlood
        chain; otherwise a generic drop shadow is delegated to
        SVGFilterFactory.  Returns the result name for chaining.
        NOTE(review): stdDeviation is hardcoded to 6 in the knockout branch
        rather than derived from the filter's blurX/blurY.
        """
        gauss = self._e.feGaussianBlur()
        gauss.set("in", "SourceAlpha")
        gauss.set("stdDeviation", "6")
        gauss.set("result", "blur")
        if swf_filter.knockout:
            # Subtract the source alpha from the blur so only the halo
            # remains, then flood it black and clip to the difference.
            composite0 = self._e.feComposite(
                in2="SourceAlpha", operator="arithmetic",
                k2="-1", k3="1", result="shadowDiff")
            flood = self._e.feFlood()
            flood.set("flood-color", "black")
            flood.set("flood-opacity", "1")
            composite1 = self._e.feComposite(
                in2="shadowDiff", operator="in", result=result)
            svg_filter.append(gauss)
            svg_filter.append(composite0)
            svg_filter.append(flood)
            svg_filter.append(composite1)
        else:
            # SWF blur values are in twips; divide by 20 to get pixels.
            SVGFilterFactory.create_drop_shadow_filter(self._e, svg_filter,
                None,
                swf_filter.blurX/20.0,
                swf_filter.blurY/20.0,
                blend_in,
                result)
        return result
+
+    def export_image(self, tag, image=None):
+        if image is not None:
+            buff = BytesIO()
+            image.save(buff, "PNG")
+            buff.seek(0)
+            with open("C:\\Users\\Administrator\\Desktop\\a.png","wb") as f:
+                f.write(buff.getvalue())
+            data_url = _encode_png(buff.read())
+            img = self._e.image()
+            img.set("id", "c%s" % tag.characterId)
+            img.set("x", "0")
+            img.set("y", "0 ")
+            img.set("width", "%s" % str(image.size[0]))
+            img.set("height", "%s" % str(image.size[1]))
+            img.set(XLINK_HREF, "%s" % data_url)
+            self.defs.append(img)
+
class SingleShapeSVGExporter(SVGExporter):
    """
    An SVG exporter which knows how to export a single shape.
    """
    def __init__(self, margin=0):
        super(SingleShapeSVGExporter, self).__init__(margin = margin)

    def export_single_shape(self, shape_tag, swf):
        """Export *shape_tag* (and everything it depends on) from *swf*.

        Builds a minimal "stunt" SWF containing the shape, its transitive
        dependencies, and one representative PlaceObject, then runs the
        normal export on it.  Returns the SVG as a BytesIO.
        """
        from swf.movie import SWF

        # find a typical use of this shape
        example_place_objects = [x for x in swf.all_tags_of_type(TagPlaceObject) if x.hasCharacter and x.characterId == shape_tag.characterId]

        if len(example_place_objects):
            place_object = example_place_objects[0]
            characters = swf.build_dictionary()
            # NOTE(review): get_dependencies() appears to return a set
            # (pop/update are used below) -- confirm against the tag classes.
            ids_to_export = place_object.get_dependencies()
            ids_exported = set()
            tags_to_export = []

            # this had better form a dag!
            # Walk the dependency graph, collecting each definition once.
            while len(ids_to_export):
                id = ids_to_export.pop()
                if id in ids_exported or id not in characters:
                    continue
                tag = characters[id]
                ids_to_export.update(tag.get_dependencies())
                tags_to_export.append(tag)
                ids_exported.add(id)
            # Reverse so dependencies are defined before they are used,
            # with the PlaceObject last.
            tags_to_export.reverse()
            tags_to_export.append(place_object)
        else:
            # The shape is never placed: synthesize a PlaceObject for it.
            place_object = TagPlaceObject()
            place_object.hasCharacter = True
            place_object.characterId = shape_tag.characterId
            tags_to_export = [ shape_tag, place_object ]

        stunt_swf = SWF()
        stunt_swf.tags = tags_to_export

        return super(SingleShapeSVGExporter, self).export(stunt_swf)
+
class SVGFilterFactory(object):
    """Helpers that build SVG filter primitive elements.

    NOTE: the order in which primitives are appended to a <filter> is
    significant -- each primitive consumes the previous one's result.
    """
    # http://commons.oreilly.com/wiki/index.php/SVG_Essentials/Filters
    # http://dev.opera.com/articles/view/svg-evolution-3-applying-polish/

    @classmethod
    def create_drop_shadow_filter(cls, e, filter, attr_in=None, blurX=0, blurY=0, blend_in=None, result=None):
        """Append a blur -> offset -> blend chain producing a drop shadow."""
        gaussianBlur = SVGFilterFactory.create_gaussian_blur(e, attr_deviaton="1", result="blur-out")
        offset = SVGFilterFactory.create_offset(e, "blur-out", blurX, blurY, "the-shadow")
        blend = SVGFilterFactory.create_blend(e, blend_in, attr_in2="the-shadow", result=result)
        filter.append(gaussianBlur)
        filter.append(offset)
        filter.append(blend)
        return result

    @classmethod
    def export_color_matrix_filter(cls, e, filter, matrix, svg_filter, attr_in=None, result='color-matrix'):
        """Append an feColorMatrix applying *matrix* (20 values, row major)."""
        attr_in = "SourceGraphic" if attr_in is None else attr_in
        fe_cxform = e.feColorMatrix()
        fe_cxform.set("in", attr_in)
        fe_cxform.set("type", "matrix")
        fe_cxform.set("values", " ".join(map(str, matrix)))
        fe_cxform.set("result", result)
        filter.append(fe_cxform)
        return result

    @classmethod
    def export_glow_filter(cls, e, filter, attr_in=None, result="glow-out"):
        """Append a gaussian blur approximating a glow."""
        attr_in = "SourceGraphic" if attr_in is None else attr_in
        gaussianBlur = SVGFilterFactory.create_gaussian_blur(e, attr_in=attr_in, attr_deviaton="1", result=result)
        filter.append(gaussianBlur)
        return result

    @classmethod
    def create_blend(cls, e, attr_in=None, attr_in2="BackgroundImage", mode="normal", result=None):
        """Build (but do not append) an feBlend element."""
        blend = e.feBlend()
        attr_in = "SourceGraphic" if attr_in is None else attr_in
        blend.set("in", attr_in)
        blend.set("in2", attr_in2)
        blend.set("mode", mode)
        if result is not None:
            blend.set("result", result)
        return blend

    @classmethod
    def create_gaussian_blur(cls, e, attr_in="SourceAlpha", attr_deviaton="3", result=None):
        """Build (but do not append) an feGaussianBlur element."""
        gaussianBlur = e.feGaussianBlur()
        gaussianBlur.set("in", attr_in)
        gaussianBlur.set("stdDeviation", attr_deviaton)
        if result is not None:
            gaussianBlur.set("result", result)
        return gaussianBlur

    @classmethod
    def create_offset(cls, e, attr_in=None, dx=0, dy=0, result=None):
        """Build (but do not append) an feOffset element; dx/dy are rounded."""
        offset = e.feOffset()
        if attr_in is not None:
            offset.set("in", attr_in)
        offset.set("dx", "%d" % round(dx))
        offset.set("dy", "%d" % round(dy))
        if result is not None:
            offset.set("result", result)
        return offset
+
class SVGBounds(object):
    """Computes the axis-aligned bounding box of an SVG document.

    Walks the element tree accumulating transformed path coordinates;
    ``matrix(...)`` transforms are combined on a stack, and <use> elements
    are resolved against groups inside <defs>.
    """

    def __init__(self, svg=None):
        # Sentinel "inverted" box: any real point immediately replaces it.
        self.minx = 1000000.0
        self.miny = 1000000.0
        self.maxx = -self.minx
        self.maxy = -self.miny
        self._stack = []
        self._matrix = self._calc_combined_matrix()
        if svg is not None:
            self._svg = svg
            self._parse(svg)

    def add_point(self, x, y):
        """Grow the box to include point (x, y)."""
        self.minx = min(self.minx, x)
        self.miny = min(self.miny, y)
        self.maxx = max(self.maxx, x)
        self.maxy = max(self.maxy, y)

    def set(self, minx, miny, maxx, maxy):
        """Set all four box edges explicitly."""
        self.minx = minx
        self.miny = miny
        self.maxx = maxx
        self.maxy = maxy

    def grow(self, margin):
        """Expand every edge outward by *margin*."""
        self.minx -= margin
        self.miny -= margin
        self.maxx += margin
        self.maxy += margin

    @property
    def height(self):
        """Box height (maxy - miny)."""
        return self.maxy - self.miny

    def merge(self, other):
        """Union this box with *other* in place."""
        self.minx = min(self.minx, other.minx)
        self.miny = min(self.miny, other.miny)
        self.maxx = max(self.maxx, other.maxx)
        self.maxy = max(self.maxy, other.maxy)

    def shrink(self, margin):
        """Contract every edge inward by *margin*."""
        self.minx += margin
        self.miny += margin
        self.maxx -= margin
        self.maxy -= margin

    @property
    def width(self):
        """Box width (maxx - minx)."""
        return self.maxx - self.minx

    def _parse(self, element):
        """Recursively walk *element*, updating the box from path data."""
        transform = element.get("transform")
        has_matrix = transform is not None and transform.find("matrix") >= 0
        if has_matrix:
            self._push_transform(transform)

        if element.tag == "{%s}path" % SVG_NS:
            self._handle_path_data(str(element.get("d")))
        elif element.tag == "{%s}use" % SVG_NS:
            href = element.get(XLINK_HREF)
            if href:
                # Resolve the reference against groups defined in <defs>.
                href = href.replace("#", "")
                els = self._svg.xpath("./svg:defs//svg:g[@id='%s']" % href,
                        namespaces=NS)
                if len(els) > 0:
                    self._parse(els[0])

        for child in element.getchildren():
            # <defs> content only contributes when referenced via <use>.
            if child.tag == "{%s}defs" % SVG_NS: continue
            self._parse(child)

        if has_matrix:
            self._pop_transform()

    def _build_matrix(self, transform):
        """Parse an SVG ``matrix(a,b,c,d,tx,ty)`` string into a Matrix2."""
        if transform.find("matrix") >= 0:
            raw = str(transform).replace("matrix(", "").replace(")", "")
            # BUG FIX: map() returns a one-shot iterator in Python 3, so the
            # original f[0] indexing raised TypeError.  Materialize a list.
            f = list(map(float, re.split(r"\s+|,", raw)))
            return Matrix2(f[0], f[1], f[2], f[3], f[4], f[5])

    def _calc_combined_matrix(self):
        """Combine the transform stack into a single Matrix2."""
        m = Matrix2()
        for mat in self._stack:
            m.append_matrix(mat)
        return m

    def _handle_path_data(self, d):
        """Feed every coordinate pair of path data *d* through the current
        transform and into the box.

        Command letters (M/L/Q) are stripped from tokens; malformed or odd
        trailing pairs are skipped.
        """
        parts = re.split(r"[\s]+", d)
        for i in range(0, len(parts), 2):
            try:
                p0 = parts[i]
                p1 = parts[i+1]
                for cmd in ("M", "L", "Q"):
                    p0 = p0.replace(cmd, "")
                    p1 = p1.replace(cmd, "")
                point = self._matrix.multiply_point([float(p0), float(p1)])
                self.add_point(point[0], point[1])
            except (ValueError, IndexError):
                # Non-numeric token or odd-length path data: skip this pair
                # (was a bare except that also hid programming errors).
                continue

    def _pop_transform(self):
        """Pop the innermost transform and recompute the combined matrix."""
        m = self._stack.pop()
        self._matrix = self._calc_combined_matrix()
        return m

    def _push_transform(self, transform):
        """Push a parsed transform and recompute the combined matrix."""
        self._stack.append(self._build_matrix(transform))
        self._matrix = self._calc_combined_matrix()
+
+def _encode_jpeg(data):
+    return "data:image/jpeg;base64," + str(base64.encodestring(data)[:-1])
+
+def _encode_png(data):
+    return "data:image/png;base64," + str(base64.encodestring(data)[:-1])
+
+def _swf_matrix_to_matrix(swf_matrix=None, need_scale=False, need_translate=True, need_rotation=False, unit_div=20.0):
+
+    if swf_matrix is None:
+        values = [1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1]
+    else:
+        values = swf_matrix.to_array()
+        if need_rotation:
+            values[1] /= unit_div
+            values[2] /= unit_div
+        if need_scale:
+            values[0] /= unit_div
+            values[3] /= unit_div
+        if need_translate:
+            values[4] /= unit_div
+            values[5] /= unit_div
+
+    return values
+
+def _swf_matrix_to_svg_matrix(swf_matrix=None, need_scale=False, need_translate=True, need_rotation=False, unit_div=20.0):
+    values = _swf_matrix_to_matrix(swf_matrix, need_scale, need_translate, need_rotation, unit_div)
+    str_values = ",".join(map(str, values))
+    return "matrix(%s)" % str_values
+

+ 229 - 0
format_convert/swf/filters.py

@@ -0,0 +1,229 @@
+from .utils import ColorUtils
+
class Filter(object):
    """Base class for all SWF surface filters."""

    def __init__(self, id):
        self._id = id

    @property
    def id(self):
        """The filter identifier supplied at construction time."""
        return self._id

    def parse(self, data):
        """Parse filter fields from *data*; concrete subclasses override."""
        pass
+        
class FilterDropShadow(Filter):
    """
    Drop Shadow Filter
    """
    def __init__(self, id):
        super(FilterDropShadow, self).__init__(id)

    def parse(self, data):
        """Read DROPSHADOWFILTER fields from *data* (a stream reader).

        Field order must match the SWF filter record layout exactly.
        Flag byte: bit7 inner shadow, bit6 knockout, bit5 composite source,
        low 5 bits = number of blur passes.
        """
        self.dropShadowColor = data.readRGBA()
        self.blurX = data.readFIXED()
        self.blurY = data.readFIXED()
        self.angle = data.readFIXED()
        self.distance = data.readFIXED()
        self.strength = data.readFIXED8()
        flags = data.readUI8()
        self.innerShadow = ((flags & 0x80) != 0)
        self.knockout = ((flags & 0x40) != 0)
        self.compositeSource = ((flags & 0x20) != 0)
        self.passes = flags & 0x1f

    def __str__(self):
        """Human-readable dump of all parsed fields."""
        s = "[DropShadowFilter] " + \
            "DropShadowColor: %s" % ColorUtils.to_rgb_string(self.dropShadowColor) + ", " + \
            "BlurX: %0.2f" % self.blurX + ", " + \
            "BlurY: %0.2f" % self.blurY + ", " + \
            "Angle: %0.2f" % self.angle + ", " + \
            "Distance: %0.2f" % self.distance + ", " + \
            "Strength: %0.2f" % self.strength + ", " + \
            "Passes: %d" % self.passes + ", " + \
            "InnerShadow: %d" % self.innerShadow + ", " + \
            "Knockout: %d" % self.knockout + ", " + \
            "CompositeSource: %d" % self.compositeSource
        return s
+        
class FilterBlur(Filter):
    """
    Blur Filter
    """
    def __init__(self, id):
        super(FilterBlur, self).__init__(id)

    def parse(self, data):
        """Read BLURFILTER fields: the two blur radii plus a pass count.

        The pass count occupies the high 5 bits of the final byte (the low
        3 bits are reserved), hence the >> 3.
        """
        self.blurX = data.readFIXED()
        self.blurY = data.readFIXED()
        self.passes = data.readUI8() >> 3

    def __str__(self):
        """Human-readable dump of all parsed fields."""
        s = "[FilterBlur] " + \
            "BlurX: %0.2f" % self.blurX + ", " + \
            "BlurY: %0.2f" % self.blurY + ", " + \
            "Passes: %d" % self.passes
        return s
+        
class FilterGlow(Filter):
    """
    Glow Filter
    """
    def __init__(self, id):
        super(FilterGlow, self).__init__(id)

    def parse(self, data):
        """Read GLOWFILTER fields from *data*.

        Flag byte: bit7 inner glow, bit6 knockout, bit5 composite source,
        low 5 bits = number of blur passes.
        """
        self.glowColor = data.readRGBA()
        self.blurX = data.readFIXED()
        self.blurY = data.readFIXED()
        self.strength = data.readFIXED8()
        flags = data.readUI8()
        self.innerGlow = ((flags & 0x80) != 0)
        self.knockout = ((flags & 0x40) != 0)
        self.compositeSource = ((flags & 0x20) != 0)
        self.passes = flags & 0x1f

    def __str__(self):
        """Human-readable dump of all parsed fields."""
        s = "[FilterGlow] " + \
            "glowColor: %s" % ColorUtils.to_rgb_string(self.glowColor) + ", " + \
            "BlurX: %0.2f" % self.blurX + ", " + \
            "BlurY: %0.2f" % self.blurY + ", " + \
            "Strength: %0.2f" % self.strength + ", " + \
            "Passes: %d" % self.passes + ", " + \
            "InnerGlow: %d" % self.innerGlow + ", " + \
            "Knockout: %d" % self.knockout
        return s
+            
class FilterBevel(Filter):
    """
    Bevel Filter
    """
    def __init__(self, id):
        super(FilterBevel, self).__init__(id)

    def parse(self, data):
        """Read BEVELFILTER fields from *data*.

        Flag byte: bit7 inner shadow, bit6 knockout, bit5 composite source,
        bit4 on-top, low 4 bits = number of blur passes.
        """
        self.shadowColor = data.readRGBA()
        self.highlightColor = data.readRGBA()
        self.blurX = data.readFIXED()
        self.blurY = data.readFIXED()
        self.angle = data.readFIXED()
        self.distance = data.readFIXED()
        self.strength = data.readFIXED8()
        flags = data.readUI8()
        self.innerShadow = ((flags & 0x80) != 0)
        self.knockout = ((flags & 0x40) != 0)
        self.compositeSource = ((flags & 0x20) != 0)
        self.onTop = ((flags & 0x10) != 0)
        self.passes = flags & 0x0f

    def __str__(self):
        """Human-readable dump of the main parsed fields."""
        s = "[FilterBevel] " + \
            "ShadowColor: %s" % ColorUtils.to_rgb_string(self.shadowColor) + ", " + \
            "HighlightColor: %s" % ColorUtils.to_rgb_string(self.highlightColor) + ", " + \
            "BlurX: %0.2f" % self.blurX + ", " + \
            "BlurY: %0.2f" % self.blurY + ", " + \
            "Angle: %0.2f" % self.angle + ", " + \
            "Passes: %d" % self.passes + ", " + \
            "Knockout: %d" % self.knockout
        return s
+          
class FilterGradientGlow(Filter):
    """
    Gradient Glow Filter
    """
    def __init__(self, id):
        self.gradientColors = []
        self.gradientRatios = []
        super(FilterGradientGlow, self).__init__(id)

    def parse(self, data):
        """Read GRADIENTGLOWFILTER fields from *data*.

        Flag byte: bit7 inner shadow, bit6 knockout, bit5 composite source,
        bit4 on-top, low 4 bits = passes (same layout as FilterBevel and the
        SWF filter record spec).
        """
        self.gradientColors = []
        self.gradientRatios = []
        self.numColors = data.readUI8()
        for i in range(0, self.numColors):
            self.gradientColors.append(data.readRGBA())
        for i in range(0, self.numColors):
            self.gradientRatios.append(data.readUI8())
        self.blurX = data.readFIXED()
        self.blurY = data.readFIXED()
        self.strength = data.readFIXED8()
        flags = data.readUI8()
        self.innerShadow = ((flags & 0x80) != 0)
        self.knockout = ((flags & 0x40) != 0)
        self.compositeSource = ((flags & 0x20) != 0)
        # BUG FIX: onTop previously tested 0x20, duplicating compositeSource;
        # the on-top flag is bit 4 (0x10), matching FilterBevel.
        self.onTop = ((flags & 0x10) != 0)
        self.passes = flags & 0x0f
+
class FilterConvolution(Filter):
    """
    Convolution Filter
    """
    def __init__(self, id):
        self.matrix = []
        super(FilterConvolution, self).__init__(id)

    def parse(self, data):
        """Read CONVOLUTIONFILTER fields: a matrixX*matrixY kernel of floats,
        a default color, and the clamp / preserve-alpha flags."""
        self.matrix = []
        self.matrixX = data.readUI8()
        self.matrixY = data.readUI8()
        self.divisor = data.readFLOAT()
        self.bias = data.readFLOAT()
        # BUG FIX: the bare names matrixX/matrixY did not exist as locals,
        # so parsing any convolution filter raised NameError.
        length = self.matrixX * self.matrixY
        for i in range(0, length):
            self.matrix.append(data.readFLOAT())
        self.defaultColor = data.readRGBA()
        flags = data.readUI8()
        self.clamp = ((flags & 0x02) != 0)
        self.preserveAlpha = ((flags & 0x01) != 0)
+
class FilterColorMatrix(Filter):
    """
    ColorMatrix Filter
    """
    def __init__(self, id):
        self.colorMatrix = []
        super(FilterColorMatrix, self).__init__(id)

    def parse(self, data):
        """Read the 20 floats of a COLORMATRIXFILTER (a 4x5 matrix, row major)."""
        self.colorMatrix = []
        for i in range(0, 20):
            self.colorMatrix.append(data.readFLOAT())
        # Normalise the offset column (every 5th value) from 0-255 down to
        # 0-1 -- presumably for direct use in an SVG feColorMatrix; verify.
        for i in range(4, 20, 5):
            self.colorMatrix[i] /= 256.0

    def tostring(self):
        """Human-readable dump of the 20 matrix values."""
        s = "[FilterColorMatrix] " + \
            " ".join(map(str, self.colorMatrix))
        return s
+                
class FilterGradientBevel(FilterGradientGlow):
    """
    Gradient Bevel Filter

    Shares the exact record layout of FilterGradientGlow, so parsing is
    inherited unchanged; only the type code differs.
    """
    def __init__(self, id):
        super(FilterGradientBevel, self).__init__(id)
+                                  
class SWFFilterFactory(object):
    """
    Filter factory: maps SWF filter type codes (0-7) to Filter subclasses.
    """
    @classmethod
    def create(cls, type, id=None):
        """Return a new Filter instance for SWF filter *type* code.

        ``id`` is stored as the filter's identifier.  BUG FIX: previously the
        *builtin* ``id`` function was passed (there was no parameter of that
        name), so every filter carried a bogus id.  The new parameter
        defaults to None and keeps the old call signature working.
        Raises Exception for unknown type codes.
        """
        filter_classes = (
            FilterDropShadow,    # 0
            FilterBlur,          # 1
            FilterGlow,          # 2
            FilterBevel,         # 3
            FilterGradientGlow,  # 4
            FilterConvolution,   # 5
            FilterColorMatrix,   # 6
            FilterGradientBevel, # 7
        )
        if 0 <= type < len(filter_classes):
            return filter_classes[type](id)
        raise Exception("Unknown filter type: %d" % type)
+

+ 371 - 0
format_convert/swf/geom.py

@@ -0,0 +1,371 @@
+import math
+
+SNAP = 0.001
+
class Vector2(object):
    """A minimal mutable 2D vector holding ``x`` and ``y`` components."""

    def __init__(self, x=0.0, y=0.0):
        self.x, self.y = x, y
+        
class Vector3(object):
    """A simple mutable 3D vector.

    Mutating operations (cross, addScalar, sub, normalize, ...) modify the
    instance in place and return ``self`` so calls can be chained.
    """

    def __init__(self, x=0, y=0, z=0):
        self.x = x
        self.y = y
        self.z = z

    def clone(self):
        """Return an independent copy of this vector."""
        return Vector3(self.x, self.y, self.z)

    def cross(self, v1, v2):
        """Set this vector to the cross product v1 x v2; return self."""
        self.x = v1.y * v2.z - v1.z * v2.y
        self.y = v1.z * v2.x - v1.x * v2.z
        self.z = v1.x * v2.y - v1.y * v2.x
        return self

    def distance(self, v):
        """Euclidean distance between this vector and *v*."""
        return math.sqrt(self.distanceSq(v))

    def distanceSq(self, v):
        """Squared Euclidean distance to *v* (avoids the sqrt)."""
        dx = self.x - v.x
        dy = self.y - v.y
        dz = self.z - v.z
        return dx*dx + dy*dy + dz*dz

    def dot(self, v):
        """Dot product with *v*."""
        return self.x * v.x + self.y * v.y + self.z * v.z

    def length(self):
        """Euclidean length (magnitude) of this vector."""
        return math.sqrt(self.lengthSq())

    def lengthSq(self):
        """Squared length (avoids the sqrt)."""
        return self.x*self.x + self.y*self.y + self.z*self.z

    def addScalar(self, s):
        """Add *s* to every component in place; return self."""
        self.x += s
        self.y += s
        self.z += s
        return self

    def divScalar(self, s):
        """Divide every component by *s* in place; return self."""
        self.x /= s
        self.y /= s
        self.z /= s
        return self

    def multScalar(self, s):
        """Multiply every component by *s* in place; return self."""
        self.x *= s
        self.y *= s
        self.z *= s
        return self

    def sub(self, a, b):
        """Set this vector to a - b (componentwise); return self."""
        self.x = a.x - b.x
        self.y = a.y - b.y
        self.z = a.z - b.z
        return self

    def subScalar(self, s):
        """Subtract *s* from every component in place; return self."""
        self.x -= s
        self.y -= s
        self.z -= s
        return self

    def equals(self, v, e=None):
        """True if *v* lies within tolerance *e* (default SNAP) on each axis."""
        e = SNAP if e is None else e
        return (abs(v.x - self.x) < e and
                abs(v.y - self.y) < e and
                abs(v.z - self.z) < e)

    def normalize(self):
        """Scale to unit length in place (no-op for zero vectors); return self."""
        # Renamed from 'len', which shadowed the builtin.
        n = self.length()
        if n > 0.0:
            self.multScalar(1.0 / n)
        return self

    def set(self, x, y, z):
        """Assign all three components at once."""
        self.x = x
        self.y = y
        self.z = z

    def tostring(self):
        """Components formatted to three decimals, space separated."""
        return "%0.3f %0.3f %0.3f" % (self.x, self.y, self.z)
+        
class Matrix2(object):
    """A 2x3 affine transform (a b c d tx ty), as used for SWF/SVG matrices."""

    def __init__(self, a=1.0, b=0.0, c=0.0, d=1.0, tx=0.0, ty=0.0):
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        self.tx = tx
        self.ty = ty

    def append(self, a, b, c, d, tx, ty):
        """Post-multiply this matrix by (a b c d tx ty) in place."""
        prev_a, prev_b = self.a, self.b
        prev_c, prev_d = self.c, self.d

        self.a = a * prev_a + b * prev_c
        self.b = a * prev_b + b * prev_d
        self.c = c * prev_a + d * prev_c
        self.d = c * prev_b + d * prev_d
        self.tx += tx * prev_a + ty * prev_c
        self.ty += tx * prev_b + ty * prev_d

    def append_matrix(self, m):
        """Post-multiply by another Matrix2 in place."""
        self.append(m.a, m.b, m.c, m.d, m.tx, m.ty)

    def multiply_point(self, vec):
        """Transform the point ``[x, y]`` and return the new ``[x, y]``."""
        px, py = vec[0], vec[1]
        return [self.a * px + self.c * py + self.tx,
                self.b * px + self.d * py + self.ty]

    def prepend(self, a, b, c, d, tx, ty):
        """Pre-multiply this matrix by (a b c d tx ty) in place."""
        old_tx = self.tx
        if a != 1.0 or b != 0.0 or c != 0.0 or d != 1.0:
            old_a, old_c = self.a, self.c
            self.a = old_a * a + self.b * c
            self.b = old_a * b + self.b * d
            self.c = old_c * a + self.d * c
            self.d = old_c * b + self.d * d
        self.tx = old_tx * a + self.ty * c + tx
        self.ty = old_tx * b + self.ty * d + ty

    def prepend_matrix(self, m):
        """Pre-multiply by another Matrix2 in place."""
        self.prepend(m.a, m.b, m.c, m.d, m.tx, m.ty)

    def rotate(self, angle):
        """Rotate by *angle* radians in place."""
        cos_a = math.cos(angle)
        sin_a = math.sin(angle)
        old_a, old_c, old_tx = self.a, self.c, self.tx
        self.a = old_a * cos_a - self.b * sin_a
        self.b = old_a * sin_a + self.b * cos_a
        self.c = old_c * cos_a - self.d * sin_a
        self.d = old_c * sin_a + self.d * cos_a
        self.tx = old_tx * cos_a - self.ty * sin_a
        self.ty = old_tx * sin_a + self.ty * cos_a

    def scale(self, x, y):
        """Scale by (x, y) in place."""
        self.a *= x
        self.d *= y
        self.tx *= x
        self.ty *= y

    def translate(self, x, y):
        """Translate by (x, y) in place."""
        self.tx += x
        self.ty += y
+              
class Matrix4(object):
    """
    A 4x4 matrix stored row-major in attributes n11..n44.

    Defaults to the identity matrix; may be built from a flat 16-element
    row-major list.  Mutating operations (identity, multiply,
    multiplyScalar, lookAt) return self so calls can be chained.
    """
    def __init__(self, data=None):
        """Initialize from a 16-element row-major list, or as identity."""
        if data is not None and len(data) == 16:
            (self.n11, self.n12, self.n13, self.n14,
             self.n21, self.n22, self.n23, self.n24,
             self.n31, self.n32, self.n33, self.n34,
             self.n41, self.n42, self.n43, self.n44) = data
        else:
            self.identity()

    def clone(self):
        """Return a new Matrix4 with the same values."""
        return Matrix4(self.flatten())

    def flatten(self):
        """Return the 16 components as a row-major list."""
        return [self.n11, self.n12, self.n13, self.n14,
                self.n21, self.n22, self.n23, self.n24,
                self.n31, self.n32, self.n33, self.n34,
                self.n41, self.n42, self.n43, self.n44]

    def identity(self):
        """Reset this matrix to the identity; returns self."""
        self.n11 = 1.0; self.n12 = 0.0; self.n13 = 0.0; self.n14 = 0.0
        self.n21 = 0.0; self.n22 = 1.0; self.n23 = 0.0; self.n24 = 0.0
        self.n31 = 0.0; self.n32 = 0.0; self.n33 = 1.0; self.n34 = 0.0
        self.n41 = 0.0; self.n42 = 0.0; self.n43 = 0.0; self.n44 = 1.0
        return self

    def multiply(self, a, b):
        """Set self = a * b (matrix product); returns self.

        All inputs are copied into locals first, so `a` or `b` may safely
        alias self.
        """
        a11 = a.n11; a12 = a.n12; a13 = a.n13; a14 = a.n14
        a21 = a.n21; a22 = a.n22; a23 = a.n23; a24 = a.n24
        a31 = a.n31; a32 = a.n32; a33 = a.n33; a34 = a.n34
        a41 = a.n41; a42 = a.n42; a43 = a.n43; a44 = a.n44
        b11 = b.n11; b12 = b.n12; b13 = b.n13; b14 = b.n14
        b21 = b.n21; b22 = b.n22; b23 = b.n23; b24 = b.n24
        b31 = b.n31; b32 = b.n32; b33 = b.n33; b34 = b.n34
        b41 = b.n41; b42 = b.n42; b43 = b.n43; b44 = b.n44

        self.n11 = a11 * b11 + a12 * b21 + a13 * b31 + a14 * b41
        self.n12 = a11 * b12 + a12 * b22 + a13 * b32 + a14 * b42
        self.n13 = a11 * b13 + a12 * b23 + a13 * b33 + a14 * b43
        self.n14 = a11 * b14 + a12 * b24 + a13 * b34 + a14 * b44

        self.n21 = a21 * b11 + a22 * b21 + a23 * b31 + a24 * b41
        self.n22 = a21 * b12 + a22 * b22 + a23 * b32 + a24 * b42
        self.n23 = a21 * b13 + a22 * b23 + a23 * b33 + a24 * b43
        self.n24 = a21 * b14 + a22 * b24 + a23 * b34 + a24 * b44

        self.n31 = a31 * b11 + a32 * b21 + a33 * b31 + a34 * b41
        self.n32 = a31 * b12 + a32 * b22 + a33 * b32 + a34 * b42
        self.n33 = a31 * b13 + a32 * b23 + a33 * b33 + a34 * b43
        self.n34 = a31 * b14 + a32 * b24 + a33 * b34 + a34 * b44

        self.n41 = a41 * b11 + a42 * b21 + a43 * b31 + a44 * b41
        self.n42 = a41 * b12 + a42 * b22 + a43 * b32 + a44 * b42
        self.n43 = a41 * b13 + a42 * b23 + a43 * b33 + a44 * b43
        self.n44 = a41 * b14 + a42 * b24 + a43 * b34 + a44 * b44
        return self

    def multiplyVector3(self, vec):
        """Transform [x, y, z] (as a list), applying the perspective divide."""
        vx = vec[0]
        vy = vec[1]
        vz = vec[2]
        d = 1.0 / (self.n41 * vx + self.n42 * vy + self.n43 * vz + self.n44)
        x = (self.n11 * vx + self.n12 * vy + self.n13 * vz + self.n14) * d
        y = (self.n21 * vx + self.n22 * vy + self.n23 * vz + self.n24) * d
        z = (self.n31 * vx + self.n32 * vy + self.n33 * vz + self.n34) * d
        return [x, y, z]

    def multiplyVec3(self, vec):
        """Transform a Vector3, applying the perspective divide."""
        vx = vec.x
        vy = vec.y
        vz = vec.z
        d = 1.0 / (self.n41 * vx + self.n42 * vy + self.n43 * vz + self.n44)
        x = (self.n11 * vx + self.n12 * vy + self.n13 * vz + self.n14) * d
        y = (self.n21 * vx + self.n22 * vy + self.n23 * vz + self.n24) * d
        z = (self.n31 * vx + self.n32 * vy + self.n33 * vz + self.n34) * d
        return Vector3(x, y, z)

    def multiplyVector4(self, v):
        """Transform a homogeneous [x, y, z, w] list (no divide)."""
        vx = v[0]; vy = v[1]; vz = v[2]; vw = v[3]

        x = self.n11 * vx + self.n12 * vy + self.n13 * vz + self.n14 * vw
        y = self.n21 * vx + self.n22 * vy + self.n23 * vz + self.n24 * vw
        z = self.n31 * vx + self.n32 * vy + self.n33 * vz + self.n34 * vw
        w = self.n41 * vx + self.n42 * vy + self.n43 * vz + self.n44 * vw

        return [x, y, z, w]

    def det(self):
        """Return the determinant.

        ( based on http://www.euclideanspace.com/maths/algebra/matrix/functions/inverse/fourD/index.htm )
        """
        return (
            self.n14 * self.n23 * self.n32 * self.n41-
            self.n13 * self.n24 * self.n32 * self.n41-
            self.n14 * self.n22 * self.n33 * self.n41+
            self.n12 * self.n24 * self.n33 * self.n41+

            self.n13 * self.n22 * self.n34 * self.n41-
            self.n12 * self.n23 * self.n34 * self.n41-
            self.n14 * self.n23 * self.n31 * self.n42+
            self.n13 * self.n24 * self.n31 * self.n42+

            self.n14 * self.n21 * self.n33 * self.n42-
            self.n11 * self.n24 * self.n33 * self.n42-
            self.n13 * self.n21 * self.n34 * self.n42+
            self.n11 * self.n23 * self.n34 * self.n42+

            self.n14 * self.n22 * self.n31 * self.n43-
            self.n12 * self.n24 * self.n31 * self.n43-
            self.n14 * self.n21 * self.n32 * self.n43+
            self.n11 * self.n24 * self.n32 * self.n43+

            self.n12 * self.n21 * self.n34 * self.n43-
            self.n11 * self.n22 * self.n34 * self.n43-
            self.n13 * self.n22 * self.n31 * self.n44+
            self.n12 * self.n23 * self.n31 * self.n44+

            self.n13 * self.n21 * self.n32 * self.n44-
            self.n11 * self.n23 * self.n32 * self.n44-
            self.n12 * self.n21 * self.n33 * self.n44+
            self.n11 * self.n22 * self.n33 * self.n44)

    def lookAt(self, eye, center, up):
        """Build a view matrix looking from `eye` toward `center`; returns self.

        All three arguments are Vector3 instances.
        """
        x = Vector3(); y = Vector3(); z = Vector3()
        z.sub(eye, center).normalize()
        x.cross(up, z).normalize()
        y.cross(z, x).normalize()
        #eye.normalize()
        self.n11 = x.x; self.n12 = x.y; self.n13 = x.z; self.n14 = -x.dot(eye)
        self.n21 = y.x; self.n22 = y.y; self.n23 = y.z; self.n24 = -y.dot(eye)
        self.n31 = z.x; self.n32 = z.y; self.n33 = z.z; self.n34 = -z.dot(eye)
        self.n41 = 0.0; self.n42 = 0.0; self.n43 = 0.0; self.n44 = 1.0
        return self

    def multiplyScalar(self, s):
        """Multiply every component by `s`; returns self."""
        self.n11 *= s; self.n12 *= s; self.n13 *= s; self.n14 *= s
        self.n21 *= s; self.n22 *= s; self.n23 *= s; self.n24 *= s
        self.n31 *= s; self.n32 *= s; self.n33 *= s; self.n34 *= s
        self.n41 *= s; self.n42 *= s; self.n43 *= s; self.n44 *= s
        return self

    @classmethod
    def inverse(cls, m1):
        """Return a new Matrix4 that is the inverse of `m1` (adjugate / det).

        # TODO: make this more efficient
        ( based on http://www.euclideanspace.com/maths/algebra/matrix/functions/inverse/fourD/index.htm )
        """
        m2 = Matrix4()
        m2.n11 = m1.n23*m1.n34*m1.n42 - m1.n24*m1.n33*m1.n42 + m1.n24*m1.n32*m1.n43 - m1.n22*m1.n34*m1.n43 - m1.n23*m1.n32*m1.n44 + m1.n22*m1.n33*m1.n44
        m2.n12 = m1.n14*m1.n33*m1.n42 - m1.n13*m1.n34*m1.n42 - m1.n14*m1.n32*m1.n43 + m1.n12*m1.n34*m1.n43 + m1.n13*m1.n32*m1.n44 - m1.n12*m1.n33*m1.n44
        m2.n13 = m1.n13*m1.n24*m1.n42 - m1.n14*m1.n23*m1.n42 + m1.n14*m1.n22*m1.n43 - m1.n12*m1.n24*m1.n43 - m1.n13*m1.n22*m1.n44 + m1.n12*m1.n23*m1.n44
        m2.n14 = m1.n14*m1.n23*m1.n32 - m1.n13*m1.n24*m1.n32 - m1.n14*m1.n22*m1.n33 + m1.n12*m1.n24*m1.n33 + m1.n13*m1.n22*m1.n34 - m1.n12*m1.n23*m1.n34
        m2.n21 = m1.n24*m1.n33*m1.n41 - m1.n23*m1.n34*m1.n41 - m1.n24*m1.n31*m1.n43 + m1.n21*m1.n34*m1.n43 + m1.n23*m1.n31*m1.n44 - m1.n21*m1.n33*m1.n44
        m2.n22 = m1.n13*m1.n34*m1.n41 - m1.n14*m1.n33*m1.n41 + m1.n14*m1.n31*m1.n43 - m1.n11*m1.n34*m1.n43 - m1.n13*m1.n31*m1.n44 + m1.n11*m1.n33*m1.n44
        m2.n23 = m1.n14*m1.n23*m1.n41 - m1.n13*m1.n24*m1.n41 - m1.n14*m1.n21*m1.n43 + m1.n11*m1.n24*m1.n43 + m1.n13*m1.n21*m1.n44 - m1.n11*m1.n23*m1.n44
        m2.n24 = m1.n13*m1.n24*m1.n31 - m1.n14*m1.n23*m1.n31 + m1.n14*m1.n21*m1.n33 - m1.n11*m1.n24*m1.n33 - m1.n13*m1.n21*m1.n34 + m1.n11*m1.n23*m1.n34
        m2.n31 = m1.n22*m1.n34*m1.n41 - m1.n24*m1.n32*m1.n41 + m1.n24*m1.n31*m1.n42 - m1.n21*m1.n34*m1.n42 - m1.n22*m1.n31*m1.n44 + m1.n21*m1.n32*m1.n44
        m2.n32 = m1.n14*m1.n32*m1.n41 - m1.n12*m1.n34*m1.n41 - m1.n14*m1.n31*m1.n42 + m1.n11*m1.n34*m1.n42 + m1.n12*m1.n31*m1.n44 - m1.n11*m1.n32*m1.n44
        # BUGFIX: the first term used m1.n13; the (3,3) cofactor is the
        # determinant of the minor over rows/cols {1,2,4}, whose term is
        # n12*n24*n41 (the old n13 term made inverse() wrong for matrices
        # with a non-trivial bottom row).
        m2.n33 = m1.n12*m1.n24*m1.n41 - m1.n14*m1.n22*m1.n41 + m1.n14*m1.n21*m1.n42 - m1.n11*m1.n24*m1.n42 - m1.n12*m1.n21*m1.n44 + m1.n11*m1.n22*m1.n44
        m2.n34 = m1.n14*m1.n22*m1.n31 - m1.n12*m1.n24*m1.n31 - m1.n14*m1.n21*m1.n32 + m1.n11*m1.n24*m1.n32 + m1.n12*m1.n21*m1.n34 - m1.n11*m1.n22*m1.n34
        m2.n41 = m1.n23*m1.n32*m1.n41 - m1.n22*m1.n33*m1.n41 - m1.n23*m1.n31*m1.n42 + m1.n21*m1.n33*m1.n42 + m1.n22*m1.n31*m1.n43 - m1.n21*m1.n32*m1.n43
        m2.n42 = m1.n12*m1.n33*m1.n41 - m1.n13*m1.n32*m1.n41 + m1.n13*m1.n31*m1.n42 - m1.n11*m1.n33*m1.n42 - m1.n12*m1.n31*m1.n43 + m1.n11*m1.n32*m1.n43
        m2.n43 = m1.n13*m1.n22*m1.n41 - m1.n12*m1.n23*m1.n41 - m1.n13*m1.n21*m1.n42 + m1.n11*m1.n23*m1.n42 + m1.n12*m1.n21*m1.n43 - m1.n11*m1.n22*m1.n43
        m2.n44 = m1.n12*m1.n23*m1.n31 - m1.n13*m1.n22*m1.n31 + m1.n13*m1.n21*m1.n32 - m1.n11*m1.n23*m1.n32 - m1.n12*m1.n21*m1.n33 + m1.n11*m1.n22*m1.n33
        m2.multiplyScalar(1.0 / m1.det())
        return m2

    @classmethod
    def rotationMatrix(cls, x, y, z, angle):
        """Return a rotation of `angle` radians about the axis (x, y, z).

        The axis is expected to be normalized; translation stays zero.
        """
        rot = Matrix4()
        c = math.cos(angle)
        s = math.sin(angle)
        t = 1 - c
        rot.n11 = t * x * x + c
        rot.n12 = t * x * y - s * z
        rot.n13 = t * x * z + s * y
        rot.n21 = t * x * y + s * z
        rot.n22 = t * y * y + c
        rot.n23 = t * y * z - s * x
        rot.n31 = t * x * z - s * y
        rot.n32 = t * y * z + s * x
        rot.n33 = t * z * z + c
        return rot

    @classmethod
    def scaleMatrix(cls, x, y, z):
        """Return a scale matrix with factors x, y, z on the diagonal."""
        m = Matrix4()
        m.n11 = x
        m.n22 = y
        m.n33 = z
        return m

    @classmethod
    def translationMatrix(cls, x, y, z):
        """Return a translation matrix moving points by (x, y, z)."""
        m = Matrix4()
        m.n14 = x
        m.n24 = y
        m.n34 = z
        return m

+ 171 - 0
format_convert/swf/movie.py

@@ -0,0 +1,171 @@
+"""
+SWF
+"""
+from .tag import SWFTimelineContainer
+from .stream import SWFStream
+from .export import SVGExporter
+try:
+    import cStringIO as StringIO
+except ImportError:
+    from io import BytesIO
+
class SWFHeaderException(Exception):
    """Signals that the parsed data does not contain a valid SWF header."""
    def __init__(self, message):
        super(SWFHeaderException, self).__init__(message)
+
class SWFHeader(object):
    """Reads and exposes the fields of a SWF file header."""
    def __init__(self, stream):
        first = stream.readUI8()
        second = stream.readUI8()
        third = stream.readUI8()
        # Signature byte 1 is 'C' (zlib), 'F' (plain) or 'Z' (lzma),
        # always followed by 'W', 'S'.
        if first not in (0x43, 0x46, 0x5A) or (second, third) != (0x57, 0x53):
            raise SWFHeaderException("not a SWF file! (invalid signature)")

        self._compressed_zlib = (first == 0x43)
        self._compressed_lzma = (first == 0x5A)
        self._version = stream.readUI8()
        self._file_length = stream.readUI32()
        # For compressed files the frame fields live inside the compressed
        # body and are filled in later (see SWF.parse).
        if not (self._compressed_zlib or self._compressed_lzma):
            self._frame_size = stream.readRECT()
            self._frame_rate = stream.readFIXED8()
            self._frame_count = stream.readUI16()

    @property
    def frame_size(self):
        """Stage dimensions as a SWFRectangle."""
        return self._frame_size

    @property
    def frame_rate(self):
        """Playback frame rate."""
        return self._frame_rate

    @property
    def frame_count(self):
        """Total number of frames."""
        return self._frame_count

    @property
    def file_length(self):
        """Uncompressed file length in bytes."""
        return self._file_length

    @property
    def version(self):
        """SWF format version."""
        return self._version

    @property
    def compressed(self):
        """True when the body is compressed (zlib or LZMA)."""
        return self._compressed_zlib or self._compressed_lzma

    @property
    def compressed_zlib(self):
        """True when the body is ZLIB-compressed."""
        return self._compressed_zlib

    @property
    def compressed_lzma(self):
        """True when the body is LZMA-compressed."""
        return self._compressed_lzma

    def __str__(self):
        template = ("   [SWFHeader]\n"
                    "       Version: %d\n"
                    "       FileLength: %d\n"
                    "       FrameSize: %s\n"
                    "       FrameRate: %d\n"
                    "       FrameCount: %d\n")
        return template % (self.version, self.file_length,
                           self.frame_size.__str__(),
                           self.frame_rate, self.frame_count)
+
class SWF(SWFTimelineContainer):
    """
    SWF class

    The SWF (pronounced 'swiff') file format delivers vector graphics, text,
    video, and sound over the Internet and is supported by Adobe Flash
    Player software. The SWF file format is designed to be an efficient
    delivery format, not a format for exchanging graphics between graphics
    editors.

    @param file: a file object with read(), seek(), tell() methods.
    """
    def __init__(self, file=None):
        super(SWF, self).__init__()
        self._data = None if file is None else SWFStream(file)
        self._header = None
        if self._data is not None:
            self.parse(self._data)

    @property
    def data(self):
        """
        Return the SWFStream object (READ ONLY)
        """
        return self._data

    @property
    def header(self):
        """ Return the SWFHeader """
        return self._header

    def export(self, exporter=None, force_stroke=False):
        """
        Export this SWF using the specified exporter.
        When no exporter is passed in the default exporter used
        is swf.export.SVGExporter.

        Exporters should extend the swf.export.BaseExporter class.

        @param exporter : the exporter to use
        @param force_stroke : set to true to force strokes on fills,
                              useful for some edge cases.
        """
        exporter = SVGExporter() if exporter is None else exporter
        if self._data is None:
            raise Exception("This SWF was not loaded! (no data)")
        if len(self.tags) == 0:
            raise Exception("This SWF doesn't contain any tags!")
        return exporter.export(self, force_stroke)

    def parse_file(self, filename):
        """ Parses the SWF from a filename """
        self.parse(open(filename, 'rb'))

    def parse(self, data):
        """
        Parses the SWF.

        The @data parameter can be a file object or a SWFStream
        """
        self._data = data = data if isinstance(data, SWFStream) else SWFStream(data)
        self._header = SWFHeader(self._data)
        if self._header.compressed:
            # Everything after the first 8 header bytes is compressed;
            # inflate it into a temporary buffer and read on from there.
            temp = BytesIO()
            if self._header.compressed_zlib:
                import zlib
                data = data.f.read()
                zip = zlib.decompressobj()
                temp.write(zip.decompress(data))
            else:
                # NOTE(review): pylzma is a third-party module; LZMA-compressed
                # SWFs cannot be parsed without it installed.
                import pylzma
                data.readUI32() #consume compressed length
                data = data.f.read()
                temp.write(pylzma.decompress(data))
            temp.seek(0)
            data = SWFStream(temp)
            # BUGFIX: these three reads belong INSIDE the compressed branch.
            # For uncompressed files SWFHeader.__init__ has already consumed
            # the frame fields; re-reading them here consumed tag bytes and
            # misaligned the stream before parse_tags().
            self._header._frame_size = data.readRECT()
            self._header._frame_rate = data.readFIXED8()
            self._header._frame_count = data.readUI16()
        self.parse_tags(data)

    def __str__(self):
        s = "[SWF]\n"
        s += self._header.__str__()
        for tag in self.tags:
            s += tag.__str__() + "\n"
        return s
+        

+ 81 - 0
format_convert/swf/sound.py

@@ -0,0 +1,81 @@
+import consts
+import tag
+import wave
+import stream
+
# Codecs this module can export: MP3 is copied through verbatim, the
# uncompressed PCM variants are re-wrapped in a WAV container.
supportedCodecs = (
    consts.AudioCodec.MP3,
    consts.AudioCodec.UncompressedNativeEndian,
    consts.AudioCodec.UncompressedLittleEndian,
)

# PCM codecs that need a WAV container (see get_wave_for_header).
uncompressed = (
    consts.AudioCodec.UncompressedNativeEndian,
    consts.AudioCodec.UncompressedLittleEndian,
)

# Sentinel results returned by reason_unsupported().
REASON_OK = None
REASON_EMPTY = 'stream is empty'
+
def get_header(stream_or_tag):
    """Return the header for a sound: the first tag of a stream (list), or
    the DefineSound tag itself."""
    if not isinstance(stream_or_tag, list):
        assert isinstance(stream_or_tag, tag.TagDefineSound), 'sound is not a stream or DefineSound tag'
        return stream_or_tag
    assert len(stream_or_tag) > 0, 'empty stream'
    return stream_or_tag[0]
+
def reason_unsupported(stream_or_tag):
    """Return None when the sound can be handled, else a human-readable reason."""
    header = get_header(stream_or_tag)

    if header.soundFormat not in supportedCodecs:
        return 'codec %s (%d) not supported' % (consts.AudioCodec.tostring(header.soundFormat),
                                                header.soundFormat)

    # A stream consisting only of its header tag carries no audio blocks.
    if isinstance(stream_or_tag, list) and len(stream_or_tag) == 1:
        return REASON_EMPTY

    return REASON_OK
+        
def supported(stream_or_tag):
    """True when the sound can be exported by this module."""
    reason = reason_unsupported(stream_or_tag)
    return reason is None
+    
def junk(stream_or_tag):
    """True when the sound is an empty stream (header but no blocks)."""
    reason = reason_unsupported(stream_or_tag)
    return reason == REASON_EMPTY
+
def get_wave_for_header(header, output):
    """Open a wave writer on `output`, configured from a SWF sound header tag.

    The caller is responsible for closing the returned writer.
    """
    w = wave.open(output, 'w')
    w.setframerate(consts.AudioSampleRate.Rates[header.soundRate])
    w.setnchannels(consts.AudioChannels.Channels[header.soundChannels])
    # BUGFIX: use integer division — under Python 3, `/ 8` yields a float and
    # the wave module requires an integral sample width (bits -> bytes).
    w.setsampwidth(consts.AudioSampleSize.Bits[header.soundSampleSize] // 8)
    return w
+    
def write_stream_to_file(stream, output):
    """Decode the blocks of a sound stream (list of tags) and write the
    audio data to `output`."""
    header = get_header(stream)

    # PCM gets wrapped in a WAV container; MP3 frames are copied raw.
    wave_out = get_wave_for_header(header, output) if header.soundFormat in uncompressed else None

    for block in stream[1:]:
        block.complete_parse_with_header(header)
        if header.soundFormat == consts.AudioCodec.MP3:
            output.write(block.mpegFrames)
        else:
            wave_out.writeframes(block.data.read())

    if wave_out:
        wave_out.close()
+
def write_sound_to_file(st, output):
    """Write the audio payload of a DefineSound tag to `output`."""
    assert isinstance(st, tag.TagDefineSound)
    if st.soundFormat == consts.AudioCodec.MP3:
        swf_stream = stream.SWFStream(st.soundData)
        # The first SI16 is the MP3 SeekSamples field; consume and discard it.
        _seek_samples = swf_stream.readSI16()
        output.write(swf_stream.read())
    elif st.soundFormat in uncompressed:
        wave_out = get_wave_for_header(st, output)
        wave_out.writeframes(st.soundData.read())
        wave_out.close()

+ 499 - 0
format_convert/swf/stream.py

@@ -0,0 +1,499 @@
+import struct, math
+from .data import *
+from .actions import *
+from .filters import SWFFilterFactory
+
+from functools import reduce
+
class SWFStream(object):
    """
    Bit-aware reader over a SWF file object.

    Byte-oriented readers (readUI8, readUI16, ...) always realign to a byte
    boundary first; readbits/readUB/readSB/readFB consume the stream bit by
    bit, buffering a partial byte between calls.
    """
    FLOAT16_EXPONENT_BASE = 15

    def __init__(self, file):
        """ Initialize with a file object """
        self.f = file
        self._bits_pending = 0
        self._partial_byte = None
        self._make_masks()

    def bin(self, s):
        """ Return a value as a binary string.

        NOTE: for s in (0, 1) this returns '0'/'1' WITHOUT the '0b' prefix
        that the builtin bin() adds for larger values; kept as-is because
        callers outside this class may rely on it.
        """
        return str(s) if s <= 1 else bin(s >> 1) + str(s & 1)

    def calc_max_bits(self, signed, values):
        """ Calculate the maximum number of bits needed to represent `values` """
        b = 0
        vmax = -10000000

        for val in values:
            if signed:
                # OR together magnitudes; negative values contribute ~val << 1
                # so the eventual sign bit is accounted for.
                b = (b | val) if val >= 0 else (b | (~val << 1))
                vmax = val if vmax < val else vmax
            else:
                b |= val
        bits = 0
        if b > 0:
            # BUGFIX: int.bit_length() equals len(bin(b)) - 2 for b > 1 and
            # also returns the correct 1 for b == 1, where the old
            # string-based computation (len('1') - 2) yielded -1.
            bits = b.bit_length()
            if signed and vmax > 0 and vmax.bit_length() >= bits:
                bits += 1
        return bits

    def close(self):
        """ Close the underlying stream """
        if self.f:
            self.f.close()

    def _make_masks(self):
        # _masks[t] == (1 << t) - 1, i.e. a mask keeping the low t bits.
        self._masks = [(1 << x) - 1 for x in range(9)]

    def _read_bytes_aligned(self, bytes):
        """ Read `bytes` whole bytes and fold them into one big-endian int. """
        buf = self.f.read(bytes)
        # BUGFIX: replaces reduce(lambda x, y: x << 8 | ord(chr(y)), ...) —
        # ord(chr(y)) was an identity detour that only worked on Python 3.
        return int.from_bytes(buf, 'big')

    def readbits(self, bits):
        """
        Read the specified number of bits from the stream.
        Returns 0 for bits == 0.
        """

        if bits == 0:
            return 0

        # fast byte-aligned path
        if bits % 8 == 0 and self._bits_pending == 0:
            return self._read_bytes_aligned(bits // 8)

        out = 0
        masks = self._masks

        def transfer_bits(x, y, n, t):
            """
            transfers t bits from the top of y_n to the bottom of x.
            then returns x and the remaining bits in y
            """
            if n == t:
                # taking all
                return (x << t) | y, 0

            mask = masks[t]           # (1 << t) - 1
            remainmask = masks[n - t] # (1 << n - t) - 1
            taken = ((y >> n - t) & mask)
            return (x << t) | taken, y & remainmask

        while bits > 0:
            if self._bits_pending > 0:
                assert self._partial_byte is not None
                take = min(self._bits_pending, bits)
                out, self._partial_byte = transfer_bits(out, self._partial_byte, self._bits_pending, take)

                if take == self._bits_pending:
                    # we took them all
                    self._partial_byte = None
                self._bits_pending -= take
                bits -= take
                continue

            r = self.f.read(1)
            if not r:
                # BUGFIX: read(1) returns b'' at EOF; the old comparison
                # against the str '' never matched under Python 3 and the
                # subsequent ord(b'') raised TypeError instead of EOFError.
                raise EOFError
            self._partial_byte = ord(r)
            self._bits_pending = 8

        return out

    def readFB(self, bits):
        """ Read a float using the specified number of bits """
        return float(self.readSB(bits)) / 65536.0

    def readSB(self, bits):
        """ Read a signed int using the specified number of bits """
        shift = 32 - bits
        return int32(self.readbits(bits) << shift) >> shift

    def readUB(self, bits):
        """ Read a unsigned int using the specified number of bits """
        return self.readbits(bits)

    def readSI8(self):
        """ Read a signed byte """
        self.reset_bits_pending()
        return struct.unpack('b', self.f.read(1))[0]

    def readUI8(self):
        """ Read a unsigned byte """
        self.reset_bits_pending()
        return struct.unpack('B', self.f.read(1))[0]

    def readSI16(self):
        """ Read a signed short """
        self.reset_bits_pending()
        return struct.unpack('h', self.f.read(2))[0]

    def readUI16(self):
        """ Read a unsigned short """
        self.reset_bits_pending()
        return struct.unpack('H', self.f.read(2))[0]

    def readSI32(self):
        """ Read a signed int """
        self.reset_bits_pending()
        return struct.unpack('<i', self.f.read(4))[0]

    def readUI32(self):
        """ Read a unsigned int """
        self.reset_bits_pending()
        return struct.unpack('<I', self.f.read(4))[0]

    def readUI64(self):
        """ Read a uint64_t """
        self.reset_bits_pending()
        return struct.unpack('<Q', self.f.read(8))[0]

    def readEncodedU32(self):
        """ Read a variable-length encoded unsigned int (7 bits per byte) """
        self.reset_bits_pending()
        result = self.readUI8()
        if result & 0x80 != 0:
            result = (result & 0x7f) | (self.readUI8() << 7)
            if result & 0x4000 != 0:
                result = (result & 0x3fff) | (self.readUI8() << 14)
                if result & 0x200000 != 0:
                    result = (result & 0x1fffff) | (self.readUI8() << 21)
                    if result & 0x10000000 != 0:
                        result = (result & 0xfffffff) | (self.readUI8() << 28)
        return result

    def readFLOAT(self):
        """ Read a float """
        self.reset_bits_pending()
        return struct.unpack('f', self.f.read(4))[0]

    def readFLOAT16(self):
        """ Read a 2 byte float (SWF half-precision format) """
        self.reset_bits_pending()
        word = self.readUI16()
        sign = -1 if ((word & 0x8000) != 0) else 1
        exponent = (word >> 10) & 0x1f
        significand = word & 0x3ff
        if exponent == 0:
            if significand == 0:
                return 0.0
            else:
                # subnormal number
                return sign * math.pow(2, 1 - SWFStream.FLOAT16_EXPONENT_BASE) * (significand / 1024.0)
        if exponent == 31:
            if significand == 0:
                return float('-inf') if sign < 0 else float('inf')
            else:
                return float('nan')
        # normal number
        return sign * math.pow(2, exponent - SWFStream.FLOAT16_EXPONENT_BASE) * (1 + significand / 1024.0)

    def readFIXED(self):
        """ Read a 16.16 fixed value """
        self.reset_bits_pending()
        return self.readSI32() / 65536.0

    def readFIXED8(self):
        """ Read a 8.8 fixed value """
        self.reset_bits_pending()
        return self.readSI16() / 256.0

    def readCXFORM(self):
        """ Read a SWFColorTransform """
        return SWFColorTransform(self)

    def readCXFORMWITHALPHA(self):
        """ Read a SWFColorTransformWithAlpha """
        return SWFColorTransformWithAlpha(self)

    def readGLYPHENTRY(self, glyphBits, advanceBits):
        """ Read a SWFGlyphEntry """
        return SWFGlyphEntry(self, glyphBits, advanceBits)

    def readGRADIENT(self, level=1):
        """ Read a SWFGradient """
        return SWFGradient(self, level)

    def readFOCALGRADIENT(self, level=1):
        """ Read a SWFFocalGradient """
        return SWFFocalGradient(self, level)

    def readGRADIENTRECORD(self, level=1):
        """ Read a SWFGradientRecord """
        return SWFGradientRecord(self, level)

    def readKERNINGRECORD(self, wideCodes):
        """ Read a SWFKerningRecord """
        return SWFKerningRecord(self, wideCodes)

    def readLANGCODE(self):
        """ Read a language code """
        self.reset_bits_pending()
        return self.readUI8()

    def readMATRIX(self):
        """ Read a SWFMatrix """
        return SWFMatrix(self)

    def readRECT(self):
        """ Read a SWFRectangle """
        r = SWFRectangle()
        r.parse(self)
        return r

    def readSHAPE(self, unit_divisor=20):
        """ Read a SWFShape """
        return SWFShape(self, 1, unit_divisor)

    def readSHAPEWITHSTYLE(self, level=1, unit_divisor=20):
        """ Read a SWFShapeWithStyle """
        return SWFShapeWithStyle(self, level, unit_divisor)

    def readCURVEDEDGERECORD(self, num_bits):
        """ Read a SWFShapeRecordCurvedEdge """
        return SWFShapeRecordCurvedEdge(self, num_bits)

    def readSTRAIGHTEDGERECORD(self, num_bits):
        """ Read a SWFShapeRecordStraightEdge """
        return SWFShapeRecordStraightEdge(self, num_bits)

    def readSTYLECHANGERECORD(self, states, fill_bits, line_bits, level=1):
        """ Read a SWFShapeRecordStyleChange """
        return SWFShapeRecordStyleChange(self, states, fill_bits, line_bits, level)

    def readFILLSTYLE(self, level=1):
        """ Read a SWFFillStyle """
        return SWFFillStyle(self, level)

    def readTEXTRECORD(self, glyphBits, advanceBits, previousRecord=None, level=1):
        """ Read a SWFTextRecord; returns None at the zero terminator """
        if self.readUI8() == 0:
            return None
        else:
            # put the peeked byte back before parsing the record
            self.seek(self.tell() - 1)
            return SWFTextRecord(self, glyphBits, advanceBits, previousRecord, level)

    def readLINESTYLE(self, level=1):
        """ Read a SWFLineStyle """
        return SWFLineStyle(self, level)

    def readLINESTYLE2(self, level=1):
        """ Read a SWFLineStyle2 """
        return SWFLineStyle2(self, level)

    def readMORPHFILLSTYLE(self, level=1):
        """ Read a SWFMorphFillStyle """
        return SWFMorphFillStyle(self, level)

    def readMORPHLINESTYLE(self, level=1):
        """ Read a SWFMorphLineStyle """
        return SWFMorphLineStyle(self, level)

    def readMORPHLINESTYLE2(self, level=1):
        """ Read a SWFMorphLineStyle2 """
        return SWFMorphLineStyle2(self, level)

    def readMORPHGRADIENT(self, level=1):
        """ Read a SWFMorphGradient """
        return SWFMorphGradient(self, level)

    def readMORPHGRADIENTRECORD(self):
        """ Read a SWFMorphGradientRecord """
        return SWFMorphGradientRecord(self)

    def readACTIONRECORD(self):
        """ Read a SWFActionRecord; returns None at the zero terminator """
        action = None
        actionCode = self.readUI8()
        if actionCode != 0:
            # codes >= 0x80 carry an explicit payload length
            actionLength = self.readUI16() if actionCode >= 0x80 else 0
            action = SWFActionFactory.create(actionCode, actionLength)
            action.parse(self)
        return action

    def readACTIONRECORDs(self):
        """ Read zero or more action records (zero-terminated) """
        out = []
        while 1:
            action = self.readACTIONRECORD()
            if action:
                out.append(action)
            else:
                break
        return out

    def readCLIPACTIONS(self, version):
        """ Read a SWFClipActions """
        return SWFClipActions(self, version)

    def readCLIPACTIONRECORD(self, version):
        """ Read a SWFClipActionRecord; returns None at the zero terminator """
        pos = self.tell()
        flags = self.readUI32() if version >= 6 else self.readUI16()
        if flags == 0:
            return None
        else:
            # rewind; the record parses its own flags
            self.seek(pos)
            return SWFClipActionRecord(self, version)

    def readCLIPEVENTFLAGS(self, version):
        """ Read a SWFClipEventFlags """
        return SWFClipEventFlags(self, version)

    def readRGB(self):
        """ Read a RGB color as 0xAARRGGBB with alpha forced to 0xff """
        self.reset_bits_pending()
        r = self.readUI8()
        g = self.readUI8()
        b = self.readUI8()
        return (0xff << 24) | (r << 16) | (g << 8) | b

    def readRGBA(self):
        """ Read a RGBA color as 0xAARRGGBB """
        self.reset_bits_pending()
        r = self.readUI8()
        g = self.readUI8()
        b = self.readUI8()
        a = self.readUI8()
        return (a << 24) | (r << 16) | (g << 8) | b

    def readSYMBOL(self):
        """ Read a SWFSymbol """
        return SWFSymbol(self)

    def readString(self):
        """ Read a NUL-terminated string and decode it as UTF-8 """
        chars = bytearray()
        byte = self.f.read(1)
        while ord(byte) > 0:
            chars.extend(byte)
            byte = self.f.read(1)
        # BUGFIX: the old code concatenated str(bytes_obj), which under
        # Python 3 produced literals like "b'a'" instead of the characters.
        return chars.decode('utf-8', errors='replace')

    def readFILTER(self):
        """ Read a SWFFilter """
        filterId = self.readUI8()
        filter = SWFFilterFactory.create(filterId)
        filter.parse(self)
        return filter

    def readFILTERLIST(self):
        """ Read a length-prefixed list of FILTERs """
        number = self.readUI8()
        return [self.readFILTER() for _ in range(number)]

    def readZONEDATA(self):
        """ Read a SWFZoneData """
        return SWFZoneData(self)

    def readZONERECORD(self):
        """ Read a SWFZoneRecord """
        return SWFZoneRecord(self)

    def readSOUNDINFO(self):
        """ Read a SWFSoundInfo """
        return SWFSoundInfo(self)

    def readSOUNDENVELOPE(self):
        """ Read a SWFSoundEnvelope """
        return SWFSoundEnvelope(self)

    def readBUTTONRECORD(self, version):
        """ Read a SWFButtonRecord; returns None when the record is invalid """
        rc = SWFButtonRecord(data=self, version=version)
        return rc if rc.valid else None

    def readBUTTONRECORDs(self, version):
        """ Read zero or more button records (zero-terminated) """
        out = []
        while 1:
            button = self.readBUTTONRECORD(version)
            if button:
                out.append(button)
            else:
                break
        return out

    def readBUTTONCONDACTION(self):
        """ Read a size-prefixed BUTTONCONDACTION; returns None at size 0 """
        size = self.readUI16()
        if size == 0:
            return None
        return SWFButtonCondAction(self)

    def readBUTTONCONDACTIONSs(self):
        """ Read zero or more button-condition actions """
        out = []
        while 1:
            action = self.readBUTTONCONDACTION()
            if action:
                out.append(action)
            else:
                break
        return out

    def readEXPORT(self):
        """ Read a SWFExport """
        return SWFExport(self)

    def readMORPHFILLSTYLEARRAY(self):
        """ Read a count-prefixed array of morph fill styles """
        count = self.readUI8()
        if count == 0xff:
            # extended count follows
            count = self.readUI16()
        return [self.readMORPHFILLSTYLE() for _ in range(count)]

    def readMORPHLINESTYLEARRAY(self, version):
        """ Read a count-prefixed array of morph line styles """
        count = self.readUI8()
        if count == 0xff:
            count = self.readUI16()
        kind = self.readMORPHLINESTYLE if version == 1 else self.readMORPHLINESTYLE2
        return [kind() for _ in range(count)]

    def readraw_tag(self):
        """ Read a SWFRawTag """
        return SWFRawTag(self)

    def readtag_header(self):
        """ Read a tag header """
        pos = self.tell()
        tag_type_and_length = self.readUI16()
        tag_length = tag_type_and_length & 0x003f
        if tag_length == 0x3f:
            # The SWF10 spec sez that this is a signed int.
            # Shouldn't it be an unsigned int?
            tag_length = self.readSI32()
        return SWFRecordHeader(tag_type_and_length >> 6, tag_length, self.tell() - pos)

    def skip_bytes(self, length):
        """ Skip over the specified number of bytes """
        self.f.seek(self.tell() + length)

    def reset_bits_pending(self):
        """ Discard any buffered partial byte (realign to a byte boundary) """
        self._bits_pending = 0

    def read(self, count=0):
        """ Read `count` bytes, or everything remaining when count <= 0 """
        return self.f.read(count) if count > 0 else self.f.read()

    def seek(self, pos, whence=0):
        """ Seek """
        self.f.seek(pos, whence)

    def tell(self):
        """ Tell """
        return self.f.tell()
+        
def int32(x):
    """Reinterpret an unsigned value in [0, 0xFFFFFFFF] as a signed 32-bit int.

    Raises OverflowError when x does not fit in 32 bits.
    """
    if x > 0xFFFFFFFF:
        raise OverflowError
    # Values with the high bit set wrap around to the negative range.
    return x - 0x100000000 if x > 0x7FFFFFFF else x

+ 2655 - 0
format_convert/swf/tag.py

@@ -0,0 +1,2655 @@
+from .consts import *
+from .data import *
+from .utils import *
+from .stream import *
+try:
+    import Image
+except ImportError:
+    from PIL import Image
+import struct
+
+try:
+    import cStringIO as StringIO
+except ImportError:
+    from io import BytesIO
+
class TagFactory(object):
    """ Factory that maps SWF tag type codes to Tag subclasses. """

    # Lazily-built map of tag type code -> tag class. Built on first use
    # so it can reference tag classes defined later in this module.
    _TAG_CLASSES = None

    @classmethod
    def create(cls, type):
        """ Return a new tag instance for the given integer type code.

        Returns None for unknown/unsupported tag types. Replaces the
        original 60-branch if/elif chain with a dictionary lookup.
        """
        if cls._TAG_CLASSES is None:
            cls._TAG_CLASSES = {
                0: TagEnd,
                1: TagShowFrame,
                2: TagDefineShape,
                4: TagPlaceObject,
                5: TagRemoveObject,
                6: TagDefineBits,
                7: TagDefineButton,
                8: TagJPEGTables,
                9: TagSetBackgroundColor,
                10: TagDefineFont,
                11: TagDefineText,
                12: TagDoAction,
                13: TagDefineFontInfo,
                14: TagDefineSound,
                15: TagStartSound,
                17: TagDefineButtonSound,
                18: TagSoundStreamHead,
                19: TagSoundStreamBlock,
                20: TagDefineBitsLossless,
                21: TagDefineBitsJPEG2,
                22: TagDefineShape2,
                24: TagProtect,
                26: TagPlaceObject2,
                28: TagRemoveObject2,
                32: TagDefineShape3,
                33: TagDefineText2,
                34: TagDefineButton2,
                35: TagDefineBitsJPEG3,
                36: TagDefineBitsLossless2,
                37: TagDefineEditText,
                39: TagDefineSprite,
                41: TagProductInfo,
                43: TagFrameLabel,
                45: TagSoundStreamHead2,
                46: TagDefineMorphShape,
                48: TagDefineFont2,
                56: TagExportAssets,
                58: TagEnableDebugger,
                59: TagDoInitAction,
                60: TagDefineVideoStream,
                61: TagVideoFrame,
                63: TagDebugID,
                64: TagEnableDebugger2,
                65: TagScriptLimits,
                69: TagFileAttributes,
                70: TagPlaceObject3,
                73: TagDefineFontAlignZones,
                74: TagCSMTextSettings,
                75: TagDefineFont3,
                76: TagSymbolClass,
                77: TagMetadata,
                78: TagDefineScalingGrid,
                82: TagDoABC,
                83: TagDefineShape4,
                84: TagDefineMorphShape2,
                86: TagDefineSceneAndFrameLabelData,
                87: TagDefineBinaryData,
                88: TagDefineFontName,
                89: TagStartSound2,
            }
        tag_class = cls._TAG_CLASSES.get(type)
        return tag_class() if tag_class is not None else None
+
class Tag(object):
    """ Abstract base class for all SWF tag types. """

    def __init__(self):
        pass

    @property
    def level(self):
        """ Tag level; subclasses override where applicable. """
        return 1

    @property
    def version(self):
        """ Minimum SWF file format version that supports this tag. """
        return 1

    @property
    def name(self):
        """ Human-readable tag name; empty for the abstract base. """
        return ""

    def parse(self, data, length, version=1):
        """ Parse this tag from a stream; the base implementation is a no-op. """
        pass

    def get_dependencies(self):
        """ Return the set of character ids this tag refers to. """
        return set()

    def __str__(self):
        return "[{0:02d}:{1}]".format(self.type, self.name)
+
class DefinitionTag(Tag):
    """ Base class for tags that define a character in the dictionary. """

    def __init__(self):
        super(DefinitionTag, self).__init__()
        self._characterId = -1

    @property
    def characterId(self):
        """ The character ID defined by this tag (-1 until parsed). """
        return self._characterId

    @characterId.setter
    def characterId(self, value):
        self._characterId = value

    def parse(self, data, length, version=1):
        """ Parse this tag from a stream; subclasses override. """
        pass

    def get_dependencies(self):
        """ Return referenced character ids, including this tag's own id. """
        return super(DefinitionTag, self).get_dependencies() | {self.characterId}
+
class DisplayListTag(Tag):
    """ Base class for tags that manipulate the display list. """

    # Character this tag operates on; -1 when not yet parsed.
    characterId = -1

    def __init__(self):
        super(DisplayListTag, self).__init__()

    def parse(self, data, length, version=1):
        """ Parse this tag from a stream; subclasses override. """
        pass

    def get_dependencies(self):
        """ Return referenced character ids, including the placed character. """
        return super(DisplayListTag, self).get_dependencies() | {self.characterId}
+
class SWFTimelineContainer(DefinitionTag):
    """ A definition tag that itself contains a timeline of tags
    (the main movie timeline or a DefineSprite). """

    def __init__(self):
        self.tags = []
        super(SWFTimelineContainer, self).__init__()

    def get_dependencies(self):
        """ Returns the character ids this tag refers to """
        s = super(SWFTimelineContainer, self).get_dependencies()
        for dt in self.all_tags_of_type(DefinitionTag):
            s.update(dt.get_dependencies())
        return s

    def parse_tags(self, data, version=1):
        """ Parse tags from `data` until an End tag (or end of file). """
        pos = data.tell()
        self.file_length = self._get_file_length(data, pos)
        tag = None
        while not isinstance(tag, TagEnd):
            tag = self.parse_tag(data)
            if tag:
                self.tags.append(tag)

    def parse_tag(self, data):
        """ Parse and return a single tag.

        Returns a TagEnd when the end of the stream is reached without an
        explicit End tag, or None for unhandled tag types (the tag data is
        skipped either way).
        """
        pos = data.tell()
        # BUGFIX: was `pos > self.file_length`, which attempted to read a
        # header when positioned exactly at EOF and crashed instead of
        # terminating cleanly.
        if pos >= self.file_length:
            # End of file encountered, no explicit end tag.
            return TagEnd()
        raw_tag = data.readraw_tag()
        tag = TagFactory.create(raw_tag.header.type)
        if tag is not None:
            data.seek(raw_tag.pos_content)
            data.reset_bits_pending()
            tag.parse(data, raw_tag.header.content_length, tag.version)
        # Resynchronize to the start of the next tag whether this one was
        # parsed or unhandled (the original seeked twice on the parsed path).
        data.seek(pos + raw_tag.header.tag_length)
        return tag

    def _get_file_length(self, data, pos):
        """ Return the total stream length, restoring the position `pos`. """
        data.f.seek(0, 2)
        length = data.tell()
        data.f.seek(pos)
        return length

    def all_tags_of_type(self, type_or_types, recurse_into_sprites=True):
        """
        Generator for all tags of the given type_or_types.

        Generates in breadth-first order, optionally including all
        sub-containers (nested sprites).
        """
        for t in self.tags:
            if isinstance(t, type_or_types):
                yield t
        if recurse_into_sprites:
            for t in self.tags:
                # recurse into nested sprites
                if isinstance(t, SWFTimelineContainer):
                    for containedtag in t.all_tags_of_type(type_or_types):
                        yield containedtag

    def build_dictionary(self):
        """
        Return a dictionary of characterIds to their defining tags.

        Raises ValueError when a characterId is defined more than once.
        """
        d = {}
        for t in self.all_tags_of_type(DefinitionTag, recurse_into_sprites=False):
            if t.characterId in d:
                raise ValueError('illegal redefinition of character')
            d[t.characterId] = t
        return d

    def collect_sound_streams(self):
        """
        Return a list of sound streams in this timeline and its children.
        The streams are returned in order with respect to the timeline.

        A stream is returned as a list: the first element is the
        TagSoundStreamHead which introduced that stream; other elements are
        the TagSoundStreamBlock tags that made up the stream body (if any).
        """
        rc = []
        current_stream = None
        for tag in self.all_tags_of_type((TagSoundStreamHead, TagSoundStreamBlock)):
            if isinstance(tag, TagSoundStreamHead):
                # a new stream begins
                current_stream = [tag]
                rc.append(current_stream)
            elif isinstance(tag, TagSoundStreamBlock):
                # NOTE(review): a block before any head raises AttributeError
                # here, exactly as in the original -- malformed input is not
                # tolerated.
                current_stream.append(tag)
        return rc

    def collect_video_streams(self):
        """
        Return a list of video streams in this timeline and its children.
        The streams are returned in order with respect to the timeline.

        A stream is returned as a list: the first element is the
        TagDefineVideoStream that introduced the stream; other elements are
        the TagVideoFrame tags belonging to it (if any).
        """
        rc = []
        streams_by_id = {}

        # scan first for all stream definitions
        for t in self.all_tags_of_type(TagDefineVideoStream):
            stream = [t]
            streams_by_id[t.characterId] = stream
            rc.append(stream)

        # then attach each frame to its /named/ stream
        for t in self.all_tags_of_type(TagVideoFrame):
            assert t.streamId in streams_by_id
            streams_by_id[t.streamId].append(t)

        return rc
+
class TagEnd(Tag):
    """
    The End tag marks the end of a file and must always be the last tag in
    a file. It is also required to end a sprite definition.
    The minimum file format version is SWF 1.
    """
    TYPE = 0

    def __init__(self):
        super(TagEnd, self).__init__()

    @property
    def name(self):
        """ The tag name """
        return "End"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagEnd.TYPE

    def __str__(self):
        return "[{0:02d}:{1}]".format(self.type, self.name)
+
class TagShowFrame(Tag):
    """
    The ShowFrame tag instructs Flash Player to display the contents of
    the display list; the file pauses for the duration of a single frame.
    The minimum file format version is SWF 1.
    """
    TYPE = 1

    def __init__(self):
        super(TagShowFrame, self).__init__()

    @property
    def name(self):
        """ The tag name """
        return "ShowFrame"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagShowFrame.TYPE

    def __str__(self):
        return "[{0:02d}:{1}]".format(self.type, self.name)
+
class TagDefineShape(DefinitionTag):
    """
    The DefineShape tag defines a shape for later use by control tags such
    as PlaceObject. The ShapeId uniquely identifies this shape as a
    'character' in the Dictionary. The ShapeBounds field is the rectangle
    that completely encloses the shape; the SHAPEWITHSTYLE structure holds
    all the paths, fill styles and line styles making up the shape.
    The minimum file format version is SWF 1.
    """
    TYPE = 2

    def __init__(self):
        self._shapes = None
        self._shape_bounds = None
        super(TagDefineShape, self).__init__()

    @property
    def name(self):
        """ The tag name """
        return "DefineShape"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagDefineShape.TYPE

    @property
    def shapes(self):
        """ The SWFShape parsed for this character. """
        return self._shapes

    @property
    def shape_bounds(self):
        """ The bounds of this shape as a SWFRectangle. """
        return self._shape_bounds

    def export(self, handler=None):
        """ Export this shape through the given handler. """
        return self.shapes.export(handler)

    def parse(self, data, length, version=1):
        """ Parse the character id, bounds and shape-with-style records. """
        self.characterId = data.readUI16()
        self._shape_bounds = data.readRECT()
        self._shapes = data.readSHAPEWITHSTYLE(self.level)

    def get_dependencies(self):
        """ Return character ids referenced by this shape. """
        deps = super(TagDefineShape, self).get_dependencies()
        deps.update(self.shapes.get_dependencies())
        return deps

    def __str__(self):
        return "{0} ID: {1}, Bounds: {2}".format(
            super(TagDefineShape, self).__str__(),
            self.characterId,
            self._shape_bounds)
+
class TagPlaceObject(DisplayListTag):
    """
    The PlaceObject tag adds a character to the display list. The CharacterId
    identifies the character to be added. The Depth field specifies the
    stacking order of the character. The Matrix field specifies the position,
    scale, and rotation of the character. If the size of the PlaceObject tag
    exceeds the end of the transformation matrix, it is assumed that a
    ColorTransform field is appended to the record. The ColorTransform field
    specifies a color effect (such as transparency) that is applied to the
    character. The same character can be added more than once to the display
    list with a different depth and transformation matrix.
    """
    TYPE = 4
    # Presence flags for the optional fields (used by PlaceObject2/3 too).
    hasClipActions = False
    hasClipDepth = False
    hasName = False
    hasRatio = False
    hasColorTransform = False
    hasMatrix = False
    hasCharacter = False
    hasMove = False
    hasImage = False
    hasClassName = False
    hasCacheAsBitmap = False
    hasBlendMode = False
    hasFilterList = False
    depth = 0
    matrix = None
    colorTransform = None
    # Forward declarations for TagPlaceObject2
    ratio = 0
    instanceName = None
    clipDepth = 0
    clipActions = None
    # Forward declarations for TagPlaceObject3
    className = None
    blendMode = 0
    bitmapCache = 0

    def __init__(self):
        self._surfaceFilterList = []
        super(TagPlaceObject, self).__init__()

    def parse(self, data, length, version=1):
        """ Parses this tag """
        pos = data.tell()
        self.characterId = data.readUI16()
        self.depth = data.readUI16()
        self.matrix = data.readMATRIX()
        self.hasCharacter = True
        self.hasMatrix = True
        if data.tell() - pos < length:
            # BUGFIX: the parsed color transform was assigned to a local
            # variable, leaving self.colorTransform as None even though
            # hasColorTransform was set; store it on the instance.
            self.colorTransform = data.readCXFORM()
            self.hasColorTransform = True

    def get_dependencies(self):
        """ Returns the character ids this tag refers to """
        s = super(TagPlaceObject, self).get_dependencies()
        if self.hasCharacter:
            s.add(self.characterId)
        return s

    @property
    def filters(self):
        """ Returns the list of surface filters. """
        return self._surfaceFilterList

    @property
    def name(self):
        """ The tag name """
        return "PlaceObject"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagPlaceObject.TYPE

    def __str__(self):
        s = super(TagPlaceObject, self).__str__() + " " + \
            "Depth: %d, " % self.depth + \
            "CharacterID: %d" % self.characterId
        if self.hasName:
            s += ", InstanceName: %s" % self.instanceName
        if self.hasMatrix:
            s += ", Matrix: %s" % self.matrix.__str__()
        if self.hasClipDepth:
            s += ", ClipDepth: %d" % self.clipDepth
        if self.hasColorTransform:
            s += ", ColorTransform: %s" % self.colorTransform.__str__()
        if self.hasFilterList:
            s += ", Filters: %d" % len(self.filters)
        if self.hasBlendMode:
            s += ", Blendmode: %d" % self.blendMode
        return s
+
class TagRemoveObject(DisplayListTag):
    """
    The RemoveObject tag removes the specified character (at the specified depth)
    from the display list.
    The minimum file format version is SWF 1.
    """
    TYPE = 5
    # Depth from which the character is removed.
    depth = 0
    def __init__(self):
        super(TagRemoveObject, self).__init__()

    @property
    def name(self):
        """ The tag name. """
        return "RemoveObject"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagRemoveObject.TYPE

    def parse(self, data, length, version=1):
        """ Parses this tag """
        # Character id first, then the display-list depth to clear.
        self.characterId = data.readUI16()
        self.depth = data.readUI16()
+
class TagDefineBits(DefinitionTag):
    """
    This tag defines a bitmap character with JPEG compression. It contains only
    the JPEG compressed image data (from the Frame Header onward); a separate
    JPEGTables tag contains the JPEG encoding data used to encode this image
    (the Tables/Misc segment).
    NOTE:
        Only one JPEGTables tag is allowed in a SWF file, and thus all bitmaps
        defined with DefineBits must share common encoding tables.
    The data in this tag begins with the JPEG SOI marker 0xFF, 0xD8 and ends
    with the EOI marker 0xFF, 0xD9. Before version 8 of the SWF file format,
    SWF files could contain an erroneous header of 0xFF, 0xD9, 0xFF, 0xD8
    before the JPEG SOI marker.
    """
    TYPE = 6
    bitmapData = None

    def __init__(self):
        self.bitmapData = BytesIO()
        self.bitmapType = BitmapType.JPEG
        super(TagDefineBits, self).__init__()

    @property
    def name(self):
        """ The tag name """
        return "DefineBits"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagDefineBits.TYPE

    def parse(self, data, length, version=1):
        """ Read the character id and the raw JPEG payload into bitmapData. """
        self.bitmapData = BytesIO()
        self.characterId = data.readUI16()
        # Payload size is the tag length minus the 2-byte character id.
        payload = length - 2
        if payload > 0:
            self.bitmapData.write(data.f.read(payload))
            self.bitmapData.seek(0)
+
class TagJPEGTables(DefinitionTag):
    """
    This tag defines the JPEG encoding table (the Tables/Misc segment) for all
    JPEG images defined using the DefineBits tag. There may only be one
    JPEGTables tag in a SWF file.
    The data in this tag begins with the JPEG SOI marker 0xFF, 0xD8 and ends
    with the EOI marker 0xFF, 0xD9. Before version 8 of the SWF file format,
    SWF files could contain an erroneous header of 0xFF, 0xD9, 0xFF, 0xD8
    before the JPEG SOI marker.
    The minimum file format version for this tag is SWF 1.
    """
    TYPE = 8
    jpegTables = None
    length = 0

    def __init__(self):
        super(TagJPEGTables, self).__init__()
        self.jpegTables = BytesIO()

    @property
    def name(self):
        """ The tag name """
        return "JPEGTables"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagJPEGTables.TYPE

    def parse(self, data, length, version=1):
        """ Store the shared JPEG encoding tables. """
        self.length = length
        if length > 0:
            self.jpegTables.write(data.f.read(length))
            self.jpegTables.seek(0)

    def __str__(self):
        return super(TagJPEGTables, self).__str__() + " Length: %d" % self.length
+
class TagSetBackgroundColor(Tag):
    """
    The SetBackgroundColor tag sets the background color of the display.
    The minimum file format version is SWF 1.
    """
    TYPE = 9
    color = 0

    def __init__(self):
        super(TagSetBackgroundColor, self).__init__()

    def parse(self, data, length, version=1):
        """ Read the background color as an RGB record. """
        self.color = data.readRGB()

    @property
    def name(self):
        """ The tag name """
        return "SetBackgroundColor"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagSetBackgroundColor.TYPE

    def __str__(self):
        return super(TagSetBackgroundColor, self).__str__() + \
            " Color: " + ColorUtils.to_rgb_string(self.color)
+
class TagDefineFont(DefinitionTag):
    """
    The DefineFont tag defines the shape outlines of each glyph used in a
    particular font. Only the glyphs that are used by subsequent DefineText
    tags are actually defined.
    DefineFont tags cannot be used for dynamic text. Dynamic text requires
    the DefineFont2 tag.
    The minimum file format version is SWF 1.
    """
    TYPE = 10
    # Reassigned to fresh lists in parse(); class-level values are defaults.
    offsetTable = []
    glyphShapeTable = []

    def __init__(self):
        super(TagDefineFont, self).__init__()

    @property
    def name(self):
        """ The tag name """
        return "DefineFont"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagDefineFont.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 1

    @property
    def unitDivisor(self):
        """ Divisor applied to glyph coordinates when reading shapes. """
        return 1

    def parse(self, data, length, version=1):
        """ Parse the glyph offset table and glyph shape table. """
        self.glyphShapeTable = []
        self.offsetTable = []
        self.characterId = data.readUI16()

        # Because the glyph shape table immediately follows the offset table,
        # the number of entries in each table (the number of glyphs in the
        # font) can be inferred by dividing the first entry in the offset
        # table by two.
        self.offsetTable.append(data.readUI16())
        # BUGFIX: use integer division -- on Python 3 `/` yields a float and
        # the range() calls below raise TypeError.
        numGlyphs = self.offsetTable[0] // 2

        for i in range(1, numGlyphs):
            self.offsetTable.append(data.readUI16())

        for i in range(numGlyphs):
            self.glyphShapeTable.append(data.readSHAPE(self.unitDivisor))
+
class TagDefineText(DefinitionTag):
    """
    The DefineText tag defines a block of static text. It describes the font,
    size, color, and exact position of every character in the text object.
    The minimum file format version is SWF 1.
    """
    TYPE = 11
    textBounds = None
    textMatrix = None

    def __init__(self):
        self._records = []
        super(TagDefineText, self).__init__()

    @property
    def name(self):
        """ The tag name """
        return "TagDefineText"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagDefineText.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 1

    def get_dependencies(self):
        """ Return character ids referenced by this tag's text records. """
        deps = super(TagDefineText, self).get_dependencies()
        for record in self.records:
            deps.update(record.get_dependencies())
        return deps

    @property
    def records(self):
        """ Return list of SWFTextRecord """
        return self._records

    def parse(self, data, length, version=1):
        """ Parse bounds, matrix and the trailing list of text records. """
        self._records = []
        self.characterId = data.readUI16()
        self.textBounds = data.readRECT()
        self.textMatrix = data.readMATRIX()
        glyphBits = data.readUI8()
        advanceBits = data.readUI8()
        record = data.readTEXTRECORD(glyphBits, advanceBits, None, self.level)
        while record is not None:
            self._records.append(record)
            record = data.readTEXTRECORD(glyphBits, advanceBits, record, self.level)
+
class TagDoAction(Tag):
    """
    DoAction instructs Flash Player to perform a list of actions when the
    current frame is complete. The actions are performed when the ShowFrame
    tag is encountered, regardless of where in the frame the DoAction tag appears.
    Starting with SWF 9, if the ActionScript3 field of the FileAttributes tag is 1,
    the contents of the DoAction tag will be ignored.
    """
    TYPE = 12
    def __init__(self):
        # Populated by parse() with SWFActionRecord objects.
        self._actions = []
        super(TagDoAction, self).__init__()

    @property
    def name(self):
        """ The tag name. """
        return "DoAction"

    @property
    def type(self):
        """ Return the SWF tag type """
        return TagDoAction.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        """ Return the minimum SWF version """
        return 3

    @property
    def actions(self):
        """ Return list of SWFActionRecord """
        return self._actions

    def parse(self, data, length, version=1):
        """ Read the full action-record list from the stream. """
        self._actions = data.readACTIONRECORDs()
+
class TagDefineFontInfo(Tag):
    """
    The DefineFontInfo tag defines a mapping from a glyph font (defined with
    DefineFont) to a device font. It provides a font name and style to pass to
    the playback platform's text engine, and a table of character codes that
    identifies the character represented by each glyph in the corresponding
    DefineFont tag, allowing the glyph indices of a DefineText tag to be
    converted to character strings.
    The presence of a DefineFontInfo tag does not force a glyph font to become
    a device font; it merely makes the option available. The actual choice
    between glyph and device usage is made according to the value of devicefont
    or the value of UseOutlines in a DefineEditText tag. If a device font is
    unavailable on a playback platform, Flash Player will fall back to glyph
    text.
    """
    TYPE = 13

    def __init__(self):
        super(TagDefineFontInfo, self).__init__()

    @property
    def name(self):
        """ The tag name """
        return "DefineFontInfo"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagDefineFontInfo.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 1

    @property
    def unitDivisor(self):
        return 1

    def get_dependencies(self):
        """ Return the referenced font's character id.

        NOTE(review): self.characterId is only set by parse(); calling this
        before parse() raises AttributeError, as in the original.
        """
        s = super(TagDefineFontInfo, self).get_dependencies()
        s.add(self.characterId)
        return s

    def parse(self, data, length, version=1):
        """ Parse the font name, style flags and glyph-to-code table. """
        self.codeTable = []

        # FontID
        self.characterId = data.readUI16()

        fontNameLen = data.readUI8()

        self.fontName = ""
        self.useGlyphText = False

        # Read in font name, one character at a time. If any of the
        # characters are non-ASCII, assume that glyph text should be
        # used rather than device text.
        for i in range(fontNameLen):
            # Renamed from `ord`, which shadowed the builtin.
            code = data.readUI8()

            if code < 128:
                self.fontName += chr(code)
            else:
                self.useGlyphText = True

        if self.useGlyphText:
            self.fontName = "Font_{0}".format(self.characterId)

        flags = data.readUI8()

        self.smallText = ((flags & 0x20) != 0)
        self.shiftJIS = ((flags & 0x10) != 0)
        self.ansi  = ((flags & 0x08) != 0)
        self.italic = ((flags & 0x04) != 0)
        self.bold = ((flags & 0x02) != 0)
        self.wideCodes = ((flags & 0x01) != 0)

        if self.wideCodes:
            # BUGFIX: integer division -- on Python 3 `/` produces a float
            # and the range() call below raises TypeError.
            numGlyphs = (length - 2 - 1 - fontNameLen - 1) // 2
        else:
            numGlyphs = length - 2 - 1 - fontNameLen - 1

        for i in range(0, numGlyphs):
            self.codeTable.append(data.readUI16() if self.wideCodes else data.readUI8())
+
class TagDefineBitsLossless(DefinitionTag):
    """
    Defines a lossless bitmap character that contains RGB bitmap data compressed
    with ZLIB. The data format used by the ZLIB library is described by
    Request for Comments (RFCs) documents 1950 to 1952.
    Two kinds of bitmaps are supported. Colormapped images define a colormap of
    up to 256 colors, each represented by a 24-bit RGB value, and then use
    8-bit pixel values to index into the colormap. Direct images store actual
    pixel color values using 15 bits (32,768 colors) or 24 bits (about 17 million colors).
    The minimum file format version for this tag is SWF 2.
    """
    TYPE = 20
    # BytesIO holding the decoded image re-encoded as PNG after parse().
    bitmapData = None
    # Raw RGBA pixel buffer assembled during parse() (bytes once parsed).
    image_buffer = ""
    # BitmapFormat code read from the tag (BIT_8 / BIT_15 / BIT_24).
    bitmap_format = 0
    bitmap_width = 0
    bitmap_height = 0
    # Colormap size minus one (only present for BIT_8 images).
    bitmap_color_size = 0
    # The raw, still-compressed zlib payload.
    zlib_bitmap_data = None
    # Row width rounded up to a 32-bit boundary (BIT_8 rows are padded).
    padded_width = 0
    def __init__(self):
        super(TagDefineBitsLossless, self).__init__()

    def parse(self, data, length, version=1):
        """ Decode the zlib-compressed bitmap and store it as PNG in bitmapData. """
        import zlib
        self.image_buffer = ""
        self.characterId = data.readUI16()
        self.bitmap_format = data.readUI8()
        self.bitmap_width = data.readUI16()
        self.bitmap_height = data.readUI16()
        if self.bitmap_format == BitmapFormat.BIT_8:
            # Colormapped: one extra byte (colormap size) precedes the payload,
            # hence 8 header bytes instead of 7.
            self.bitmap_color_size = data.readUI8()
            self.zlib_bitmap_data = data.f.read(length-8)
        else:
            self.zlib_bitmap_data = data.f.read(length-7)

        # decompress zlib encoded bytes
        compressed_length = len(self.zlib_bitmap_data)
        # NOTE(review): `zip` shadows the builtin; local to this method only.
        zip = zlib.decompressobj()
        temp = BytesIO()
        temp.write(zip.decompress(self.zlib_bitmap_data))
        temp.seek(0, 2)
        uncompressed_length = temp.tell()
        temp.seek(0)

        # padding : should be aligned to 32 bit boundary
        self.padded_width = self.bitmap_width
        while self.padded_width % 4 != 0:
            self.padded_width += 1
        # Total number of (padded) pixels to read for colormapped images.
        t = self.padded_width * self.bitmap_height

        # Lossless2 carries an alpha channel; plain Lossless is opaque.
        is_lossless2 = (type(self) == TagDefineBitsLossless2)
        im = None
        self.bitmapData = BytesIO()

        indexed_colors = []
        if self.bitmap_format == BitmapFormat.BIT_8:
            # Read the colormap: RGB (plus A for Lossless2) per entry.
            for i in range(0, self.bitmap_color_size + 1):
                r = ord(temp.read(1))
                g = ord(temp.read(1))
                b = ord(temp.read(1))
                a = ord(temp.read(1)) if is_lossless2 else 0xff
                indexed_colors.append(struct.pack("BBBB", r, g, b, a))

            # create the image buffer
            s = BytesIO()
            for i in range(t):
                # Each byte indexes the colormap (modulo guards bad indices).
                a = ord(temp.read(1))
                s.write(indexed_colors[a%len(indexed_colors)])
            self.image_buffer = s.getvalue()
            s.close()

            # Build at padded width, then crop the padding columns away.
            im = Image.frombytes("RGBA", (self.padded_width, self.bitmap_height), self.image_buffer)
            im = im.crop((0, 0, self.bitmap_width, self.bitmap_height))

        elif self.bitmap_format == BitmapFormat.BIT_15:
            raise Exception("DefineBitsLossless: BIT_15 not yet implemented")
        elif self.bitmap_format == BitmapFormat.BIT_24:
            # we have no padding, since PIX24s are 32-bit aligned
            t = self.bitmap_width * self.bitmap_height
            # read PIX24's
            s = BytesIO()
            for i in range(0, t):
                if not is_lossless2:
                    temp.read(1) # reserved, always 0
                a = ord(temp.read(1)) if is_lossless2 else 0xff
                r = ord(temp.read(1))
                g = ord(temp.read(1))
                b = ord(temp.read(1))
                s.write(struct.pack("BBBB", r, g, b, a))
            self.image_buffer = s.getvalue()
            im = Image.frombytes("RGBA", (self.bitmap_width, self.bitmap_height), self.image_buffer)
        else:
            raise Exception("unhandled bitmap format! %s %d" % (BitmapFormat.tostring(self.bitmap_format), self.bitmap_format))

        if not im is None:
            # Re-encode as PNG and rewind so consumers can read immediately.
            im.save(self.bitmapData, "PNG")
            self.bitmapData.seek(0)
            self.bitmapType = ImageUtils.get_image_type(self.bitmapData)

    @property
    def name(self):
        """ The tag name. """
        return "DefineBitsLossless"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagDefineBitsLossless.TYPE
+
class TagDefineBitsJPEG2(TagDefineBits):
    """
    This tag defines a bitmap character with JPEG compression. It differs from
    DefineBits in that it contains both the JPEG encoding table and the JPEG
    image data, allowing multiple JPEG images with differing encoding tables
    to be defined within a single SWF file.
    The data in this tag begins with the JPEG SOI marker 0xFF, 0xD8 and ends
    with the EOI marker 0xFF, 0xD9. Before version 8 of the SWF file format,
    SWF files could contain an erroneous header of 0xFF, 0xD9, 0xFF, 0xD8
    before the JPEG SOI marker.
    In addition to specifying JPEG data, DefineBitsJPEG2 can also contain PNG
    image data and non-animated GIF89a image data.

    - If ImageData begins with the eight bytes 0x89 0x50 0x4E 0x47 0x0D 0x0A 0x1A 0x0A,
      the ImageData contains PNG data.
    - If ImageData begins with the six bytes 0x47 0x49 0x46 0x38 0x39 0x61, the ImageData
      contains GIF89a data.

    The minimum file format version for this tag is SWF 2. The minimum file format
    version for embedding PNG of GIF89a data is SWF 8.
    """
    TYPE = 21
    bitmapType = 0

    def __init__(self):
        super(TagDefineBitsJPEG2, self).__init__()

    @property
    def name(self):
        """ The tag name """
        return "DefineBitsJPEG2"

    @property
    def type(self):
        """ The numeric tag type code. """
        return TagDefineBitsJPEG2.TYPE

    @property
    def version(self):
        """ Minimum SWF version: 2 for JPEG payloads, 8 for PNG/GIF89a. """
        if self.bitmapType == BitmapType.JPEG:
            return 2
        return 8

    @property
    def level(self):
        return 2

    def parse(self, data, length, version=1):
        """ Read the payload, then sniff its actual image type. """
        super(TagDefineBitsJPEG2, self).parse(data, length, version)
        self.bitmapType = ImageUtils.get_image_type(self.bitmapData)
+
class TagDefineShape2(TagDefineShape):
    """
    DefineShape2 (tag type 22): extends DefineShape with support for more
    than 255 styles per style list and multiple style lists in one shape.
    Minimum SWF version: 2.
    """
    TYPE = 22

    def __init__(self):
        super(TagDefineShape2, self).__init__()

    @property
    def name(self):
        return "DefineShape2"

    @property
    def type(self):
        return TagDefineShape2.TYPE

    @property
    def version(self):
        return 2

    @property
    def level(self):
        return 2
+
class TagPlaceObject2(TagPlaceObject):
    """
    PlaceObject2 (tag type 26): adds a character to the display list or
    modifies the one already at the given depth.

    A leading flag byte selects which optional fields follow; only Depth is
    mandatory.  hasMove/hasCharacter together encode the operation:
    add a new character (move=0, char=1), modify the character at Depth
    (move=1, char=0), or replace it (move=1, char=1).  Optional fields, in
    stream order: CharacterId, Matrix, ColorTransform, Ratio (morph progress
    0..65535), Name (instance name, e.g. for SetTarget), ClipDepth (top-most
    masked depth; 0 = not a clipping character) and ClipActions (sprite event
    handlers, SWF 5+).
    """
    TYPE = 26

    def __init__(self):
        super(TagPlaceObject2, self).__init__()

    def parse(self, data, length, version=1):
        # Flag byte: bit 7 (clip actions) down to bit 0 (move).
        flags = data.readUI8()
        (self.hasClipActions,
         self.hasClipDepth,
         self.hasName,
         self.hasRatio,
         self.hasColorTransform,
         self.hasMatrix,
         self.hasCharacter,
         self.hasMove) = (bool(flags & (1 << bit)) for bit in range(7, -1, -1))

        self.depth = data.readUI16()

        # Optional fields appear in this fixed order when flagged.
        if self.hasCharacter:
            self.characterId = data.readUI16()
        if self.hasMatrix:
            self.matrix = data.readMATRIX()
        if self.hasColorTransform:
            self.colorTransform = data.readCXFORMWITHALPHA()
        if self.hasRatio:
            self.ratio = data.readUI16()
        if self.hasName:
            self.instanceName = data.readString()
        if self.hasClipDepth:
            self.clipDepth = data.readUI16()
        if self.hasClipActions:
            self.clipActions = data.readCLIPACTIONS(version)

    @property
    def name(self):
        return "PlaceObject2"

    @property
    def type(self):
        return TagPlaceObject2.TYPE

    @property
    def level(self):
        return 2

    @property
    def version(self):
        return 3
+
class TagRemoveObject2(TagRemoveObject):
    """
    RemoveObject2 (tag type 28): removes the character at the given depth
    from the display list.  Minimum SWF version: 3.
    """
    TYPE = 28

    def __init__(self):
        super(TagRemoveObject2, self).__init__()

    @property
    def name(self):
        return "RemoveObject2"

    @property
    def type(self):
        return TagRemoveObject2.TYPE

    @property
    def level(self):
        return 2

    @property
    def version(self):
        return 3

    def parse(self, data, length, version=1):
        # Unlike RemoveObject, only the depth is stored (no characterId).
        self.depth = data.readUI16()
+
class TagDefineShape3(TagDefineShape2):
    """
    DefineShape3 (tag type 32): extends DefineShape2 by widening all RGB
    color fields to RGBA (adds opacity).  Minimum SWF version: 3.
    """
    TYPE = 32

    def __init__(self):
        super(TagDefineShape3, self).__init__()

    @property
    def name(self):
        return "DefineShape3"

    @property
    def type(self):
        return TagDefineShape3.TYPE

    @property
    def level(self):
        return 3

    @property
    def version(self):
        return 3
+
class TagDefineText2(TagDefineText):
    """
    DefineText2 (tag type 33): a block of static text — font, size, color
    and exact position of every character.  Minimum SWF version: 3.
    """
    TYPE = 33

    def __init__(self):
        super(TagDefineText2, self).__init__()

    @property
    def name(self):
        return "DefineText2"

    @property
    def type(self):
        return TagDefineText2.TYPE

    @property
    def level(self):
        return 2

    @property
    def version(self):
        return 3
+
class TagDefineBitsJPEG3(TagDefineBitsJPEG2):
    """
    DefineBitsJPEG3 (tag type 35): DefineBitsJPEG2 plus a separate alpha
    channel, since JPEG itself carries no transparency.  The alpha plane is
    ZLIB-compressed (RFC 1950-1952).  The image payload may also be PNG or
    GIF89a data (SWF 8+); in that case the alpha block is not supported.
    Minimum SWF version: 3 (8 for PNG/GIF89a payloads).
    """
    TYPE = 35

    def __init__(self):
        self.bitmapAlphaData = BytesIO()
        super(TagDefineBitsJPEG3, self).__init__()

    @property
    def name(self):
        return "DefineBitsJPEG3"

    @property
    def type(self):
        return TagDefineBitsJPEG3.TYPE

    @property
    def version(self):
        return 3 if self.bitmapType == BitmapType.JPEG else 8

    @property
    def level(self):
        return 3

    def parse(self, data, length, version=1):
        """Read characterId, the image payload and the optional alpha plane."""
        import zlib
        self.characterId = data.readUI16()
        alphaOffset = data.readUI32()  # byte length of the image payload
        self.bitmapAlphaData = BytesIO()
        self.bitmapData = BytesIO()
        self.bitmapData.write(data.f.read(alphaOffset))
        self.bitmapData.seek(0)
        self.bitmapType = ImageUtils.get_image_type(self.bitmapData)
        # Whatever remains after the 6 header bytes (UI16 + UI32) and the
        # image payload is the zlib-compressed alpha channel.
        alphaDataSize = length - alphaOffset - 6
        if alphaDataSize > 0:
            self.bitmapAlphaData.write(data.f.read(alphaDataSize))
            self.bitmapAlphaData.seek(0)
            # Inflate the alpha plane.  (Fix: the decompressor used to be
            # bound to the name `zip`, shadowing the builtin.)
            decompressor = zlib.decompressobj()
            inflated = BytesIO()
            inflated.write(decompressor.decompress(self.bitmapAlphaData.read()))
            inflated.seek(0)
            self.bitmapAlphaData = inflated
+
class TagDefineBitsLossless2(TagDefineBitsLossless):
    """
    DefineBitsLossless2 (tag type 36): DefineBitsLossless with opacity —
    colormapped images use RGBA entries and direct images store 32-bit ARGB
    pixels.  The intermediate 15-bit depth is not available here.
    Minimum SWF version: 3.
    """
    TYPE = 36

    def __init__(self):
        super(TagDefineBitsLossless2, self).__init__()

    @property
    def name(self):
        return "DefineBitsLossless2"

    @property
    def type(self):
        return TagDefineBitsLossless2.TYPE

    @property
    def level(self):
        return 2

    @property
    def version(self):
        return 3
+
class TagDefineSprite(SWFTimelineContainer):
    """
    DefineSprite (tag type 39): a sprite character — a character ID and
    frame count followed by a nested series of control tags, terminated by
    an End tag.  Definition tags are not allowed inside a sprite; every
    character it references must already be defined in the main body of the
    file.  Minimum SWF version: 3.
    """
    TYPE = 39
    frameCount = 0

    def __init__(self):
        super(TagDefineSprite, self).__init__()

    def parse(self, data, length, version=1):
        self.characterId = data.readUI16()
        self.frameCount = data.readUI16()
        # The sprite body is itself a tag stream; delegate to the container.
        self.parse_tags(data, version)

    def get_dependencies(self):
        deps = super(TagDefineSprite, self).get_dependencies()
        deps.add(self.characterId)
        return deps

    @property
    def name(self):
        return "DefineSprite"

    @property
    def type(self):
        return TagDefineSprite.TYPE

    def __str__(self):
        return "%s ID: %d" % (super(TagDefineSprite, self).__str__(),
                              self.characterId)
+
class TagFrameLabel(Tag):
    """
    FrameLabel (tag type 43): names the current frame so that
    ActionGoToLabel can address it.  Minimum SWF version: 3.
    """
    TYPE = 43
    frameName = ""
    namedAnchorFlag = False

    def __init__(self):
        super(TagFrameLabel, self).__init__()

    @property
    def name(self):
        return "FrameLabel"

    @property
    def type(self):
        return TagFrameLabel.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 3

    def parse(self, data, length, version=1):
        start = data.tell()
        self.frameName = data.readString()
        # Any byte left after the name is the named-anchor flag.
        if (data.tell() - start) < length:
            data.readUI8()  # flag value itself, always 1
            self.namedAnchorFlag = True
+
class TagDefineMorphShape(DefinitionTag):
    """
    DefineMorphShape (tag type 46): stores the start and end states of a
    morph sequence; the ratio field of PlaceObject2 selects how far the
    morph has progressed when displayed.  Minimum SWF version: 3.
    """
    TYPE = 46

    def __init__(self):
        self._morphFillStyles = []
        self._morphLineStyles = []
        super(TagDefineMorphShape, self).__init__()

    @property
    def name(self):
        return "DefineMorphShape"

    @property
    def type(self):
        return TagDefineMorphShape.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 3

    @property
    def morph_fill_styles(self):
        """List of SWFMorphFillStyle."""
        return self._morphFillStyles

    @property
    def morph_line_styles(self):
        """List of SWFMorphLineStyle."""
        return self._morphLineStyles

    def parse(self, data, length, version=1):
        self._morphFillStyles = []
        self._morphLineStyles = []
        self.characterId = data.readUI16()
        self.startBounds = data.readRECT()
        self.endBounds = data.readRECT()
        # A UI32 offset field precedes the style arrays; we read the stream
        # sequentially so the value is unused, but it must be consumed.
        offset = data.readUI32()

        self._morphFillStyles = data.readMORPHFILLSTYLEARRAY()
        # This tag always uses version-1 morph line styles.
        self._morphLineStyles = data.readMORPHLINESTYLEARRAY(version=1)
        self.startEdges = data.readSHAPE()
        self.endEdges = data.readSHAPE()
+
class TagDefineFont2(TagDefineFont):
    """
    DefineFont2 (tag type 48): embeds a font as glyph shapes plus a code
    table, with optional layout (ascent/descent/leading, per-glyph advance
    and bounds, kerning pairs).  Minimum SWF version: 3.
    """
    TYPE = 48

    def __init__(self):
        self.glyphShapeTable = []
        super(TagDefineFont2, self).__init__()

    @property
    def name(self):
        return "DefineFont2"

    @property
    def type(self):
        return TagDefineFont2.TYPE

    @property
    def level(self):
        return 2

    @property
    def version(self):
        return 3

    @property
    def unitDivisor(self):
        # Divisor applied to glyph coordinates when reading SHAPE records.
        return 20

    def parse(self, data, length, version=1):
        """Read the font header, glyph shapes, code table and layout data."""
        self.glyphShapeTable = []
        self.codeTable = []
        self.fontAdvanceTable = []
        self.fontBoundsTable = []
        self.fontKerningTable = []

        self.characterId = data.readUI16()

        flags = data.readUI8()
        self.hasLayout = ((flags & 0x80) != 0)
        self.shiftJIS = ((flags & 0x40) != 0)
        self.smallText = ((flags & 0x20) != 0)
        self.ansi = ((flags & 0x10) != 0)
        self.wideOffsets = ((flags & 0x08) != 0)  # 32- vs 16-bit offsets
        self.wideCodes = ((flags & 0x04) != 0)    # 16- vs 8-bit glyph codes
        self.italic = ((flags & 0x02) != 0)
        self.bold = ((flags & 0x01) != 0)
        self.languageCode = data.readLANGCODE()

        # Font name is a length-prefixed byte string (kept as raw bytes).
        fontNameLen = data.readUI8()
        self.fontName = data.f.read(fontNameLen)

        numGlyphs = data.readUI16()

        # The offset table locates each glyph SHAPE relative to the start of
        # the table itself; entry width depends on the wideOffsets flag.
        # (Fix: removed the dead `numSkip` local and its stale comment.)
        startOfOffsetTable = data.f.tell()
        offsetTable = []
        for _ in range(numGlyphs):
            offsetTable.append(data.readUI32() if self.wideOffsets else data.readUI16())

        codeTableOffset = data.readUI32() if self.wideOffsets else data.readUI16()
        for glyphOffset in offsetTable:
            data.f.seek(startOfOffsetTable + glyphOffset)
            self.glyphShapeTable.append(data.readSHAPE(self.unitDivisor))
        data.f.seek(startOfOffsetTable + codeTableOffset)
        for _ in range(numGlyphs):
            self.codeTable.append(data.readUI16() if self.wideCodes else data.readUI8())

        if self.hasLayout:
            self.ascent = data.readSI16()
            self.descent = data.readSI16()
            self.leading = data.readSI16()
            for _ in range(numGlyphs):
                self.fontAdvanceTable.append(data.readSI16())
            for _ in range(numGlyphs):
                self.fontBoundsTable.append(data.readRECT())
            kerningCount = data.readUI16()
            for _ in range(kerningCount):
                self.fontKerningTable.append(data.readKERNINGRECORD(self.wideCodes))
+
class TagFileAttributes(Tag):
    """
    FileAttributes (tag type 69): characteristics of the SWF file.
    Required as the first tag in SWF 8+ files, optional earlier.
    HasMetadata advertises a Metadata tag (useful for search engines);
    UseNetwork trades local-file access for network access when the SWF is
    loaded locally (Flash Player 8+).
    """
    TYPE = 69

    def __init__(self):
        super(TagFileAttributes, self).__init__()

    @property
    def name(self):
        return "FileAttributes"

    @property
    def type(self):
        return TagFileAttributes.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 8

    def parse(self, data, length, version=1):
        flags = data.readUI8()
        self.useDirectBlit = ((flags & 0x40) != 0)
        self.useGPU = ((flags & 0x20) != 0)
        self.hasMetadata = ((flags & 0x10) != 0)
        self.actionscript3 = ((flags & 0x08) != 0)
        self.useNetwork = ((flags & 0x01) != 0)
        data.skip_bytes(3)  # remaining reserved bytes of the tag body

    def __str__(self):
        flag_fields = (
            ("useDirectBlit", self.useDirectBlit),
            ("useGPU", self.useGPU),
            ("hasMetadata", self.hasMetadata),
            ("actionscript3", self.actionscript3),
            ("useNetwork", self.useNetwork),
        )
        summary = ", ".join("%s: %d" % pair for pair in flag_fields)
        return super(TagFileAttributes, self).__str__() + " " + summary
+
class TagPlaceObject3(TagPlaceObject2):
    """
    PlaceObject3 (tag type 70): PlaceObject2 plus a second flag byte that
    adds class name, bitmap caching, blend mode and surface filter fields.
    """
    TYPE = 70

    def __init__(self):
        super(TagPlaceObject3, self).__init__()

    def parse(self, data, length, version=1):
        # First flag byte: identical layout to PlaceObject2.
        flags = data.readUI8()
        (self.hasClipActions,
         self.hasClipDepth,
         self.hasName,
         self.hasRatio,
         self.hasColorTransform,
         self.hasMatrix,
         self.hasCharacter,
         self.hasMove) = (bool(flags & (1 << bit)) for bit in range(7, -1, -1))
        # Second flag byte: PlaceObject3-specific extensions.
        flags2 = data.readUI8()
        self.hasImage = ((flags2 & 0x10) != 0)
        self.hasClassName = ((flags2 & 0x08) != 0)
        self.hasCacheAsBitmap = ((flags2 & 0x04) != 0)
        self.hasBlendMode = ((flags2 & 0x02) != 0)
        self.hasFilterList = ((flags2 & 0x01) != 0)
        self.depth = data.readUI16()

        if self.hasClassName:
            self.className = data.readString()
        if self.hasCharacter:
            self.characterId = data.readUI16()
        if self.hasMatrix:
            self.matrix = data.readMATRIX()
        if self.hasColorTransform:
            self.colorTransform = data.readCXFORMWITHALPHA()
        if self.hasRatio:
            self.ratio = data.readUI16()
        if self.hasName:
            self.instanceName = data.readString()
        if self.hasClipDepth:
            self.clipDepth = data.readUI16()
        if self.hasFilterList:
            numberOfFilters = data.readUI8()
            for _ in range(numberOfFilters):
                self._surfaceFilterList.append(data.readFILTER())
        if self.hasBlendMode:
            self.blendMode = data.readUI8()
        if self.hasCacheAsBitmap:
            self.bitmapCache = data.readUI8()
        if self.hasClipActions:
            self.clipActions = data.readCLIPACTIONS(version)

    @property
    def name(self):
        return "PlaceObject3"

    @property
    def type(self):
        return TagPlaceObject3.TYPE
+
class TagDefineFontAlignZones(Tag):
    """
    DefineFontAlignZones (tag type 73): per-glyph alignment zone records
    for an existing font, used for advanced text rendering.
    """
    TYPE = 73

    def __init__(self):
        super(TagDefineFontAlignZones, self).__init__()

    @property
    def name(self):
        return "DefineFontAlignZones"

    @property
    def type(self):
        return TagDefineFontAlignZones.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 8

    def parse(self, data, length, version=1):
        self.zoneTable = []
        self.fontId = data.readUI16()
        # The CSM table hint lives in the top two bits of the next byte.
        self.csmTableHint = (data.readUI8() >> 6)
        # Zone records fill the rest of the tag (3 header bytes consumed).
        recordsEndPos = data.tell() + length - 3
        while data.tell() < recordsEndPos:
            self.zoneTable.append(data.readZONERECORD())
+
class TagCSMTextSettings(Tag):
    """
    CSMTextSettings (tag type 74): advanced text-rendering settings
    (FlashType usage, grid fitting, thickness, sharpness) for a text field.
    """
    TYPE = 74

    def __init__(self):
        super(TagCSMTextSettings, self).__init__()

    @property
    def name(self):
        return "CSMTextSettings"

    @property
    def type(self):
        return TagCSMTextSettings.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 8

    def parse(self, data, length, version=1):
        self.textId = data.readUI16()
        self.useFlashType = data.readUB(2)
        self.gridFit = data.readUB(3)
        data.readUB(3)  # reserved bits, always 0
        self.thickness = data.readFIXED()
        self.sharpness = data.readFIXED()
        data.readUI8()  # reserved byte, always 0
+
class TagDefineFont3(TagDefineFont2):
    """
    DefineFont3 (tag type 75): same wire format as DefineFont2; parsing is
    fully inherited.  Minimum SWF version: 8.
    """
    TYPE = 75

    def __init__(self):
        super(TagDefineFont3, self).__init__()

    @property
    def name(self):
        return "DefineFont3"

    @property
    def type(self):
        return TagDefineFont3.TYPE

    @property
    def level(self):
        return 2

    @property
    def version(self):
        return 8
+
class TagSymbolClass(Tag):
    """
    SymbolClass (tag type 76): associates character IDs with AS3 class
    names; parsed as a list of SYMBOL records.
    """
    TYPE = 76

    def __init__(self):
        self.symbols = []
        super(TagSymbolClass, self).__init__()

    @property
    def name(self):
        return "SymbolClass"

    @property
    def type(self):
        return TagSymbolClass.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        # Educated guess: the SWF10 spec does not state a minimum version.
        return 9

    def parse(self, data, length, version=1):
        self.symbols = []
        numSymbols = data.readUI16()
        for _ in range(numSymbols):
            self.symbols.append(data.readSYMBOL())
+
class TagMetadata(Tag):
    """
    Metadata (tag type 77): an XML metadata string describing the SWF file.
    """
    TYPE = 77

    def __init__(self):
        super(TagMetadata, self).__init__()

    @property
    def name(self):
        return "Metadata"

    @property
    def type(self):
        return TagMetadata.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 1

    def parse(self, data, length, version=1):
        self.xmlString = data.readString()

    def __str__(self):
        return super(TagMetadata, self).__str__() + " xml: %r" % self.xmlString
+
class TagDoABC(Tag):
    """
    DoABC (tag type 82): an embedded ActionScript 3 bytecode (ABC) block
    with a flag word and a name, followed by the raw ABC bytes.
    """
    TYPE = 82

    def __init__(self):
        super(TagDoABC, self).__init__()

    @property
    def name(self):
        return "DoABC"

    @property
    def type(self):
        return TagDoABC.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 9

    def parse(self, data, length, version=1):
        pos = data.tell()
        flags = data.readUI32()
        self.lazyInitializeFlag = ((flags & 0x01) != 0)
        self.abcName = data.readString()
        # Everything left in the tag after flags + name is raw ABC bytecode.
        self.bytes = data.f.read(length - (data.tell() - pos))
+
class TagDefineShape4(TagDefineShape3):
    """
    DefineShape4 (tag type 83): adds edge bounds and stroke-hint flags on
    top of DefineShape3.  Minimum SWF version: 8.
    """
    TYPE = 83

    def __init__(self):
        super(TagDefineShape4, self).__init__()

    @property
    def name(self):
        return "DefineShape4"

    @property
    def type(self):
        return TagDefineShape4.TYPE

    @property
    def level(self):
        return 4

    @property
    def version(self):
        return 8

    def parse(self, data, length, version=1):
        self.characterId = data.readUI16()
        self._shape_bounds = data.readRECT()
        self.edge_bounds = data.readRECT()
        flags = data.readUI8()
        self.uses_fillwinding_rule = ((flags & 0x04) != 0)
        self.uses_non_scaling_strokes = ((flags & 0x02) != 0)
        self.uses_scaling_strokes = ((flags & 0x01) != 0)
        self._shapes = data.readSHAPEWITHSTYLE(self.level)
+
class TagDefineSceneAndFrameLabelData(Tag):
    """
    DefineSceneAndFrameLabelData (tag type 86): scene offsets/names and
    frame-number/label pairs for the main timeline.
    """
    TYPE = 86

    def __init__(self):
        self.scenes = []
        self.frameLabels = []
        super(TagDefineSceneAndFrameLabelData, self).__init__()

    @property
    def name(self):
        return "DefineSceneAndFrameLabelData"

    @property
    def type(self):
        return TagDefineSceneAndFrameLabelData.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 9

    def parse(self, data, length, version=1):
        self.sceneCount = data.readEncodedU32()

        # A sceneCount with the top bit set is nonsensical; bail out rather
        # than loop on it (possible CVE-2007-0071 exploit attempt).
        if self.sceneCount >= 0x80000000:
            return

        self.scenes = []
        self.frameLabels = []
        for _ in range(self.sceneCount):
            sceneOffset = data.readEncodedU32()
            sceneName = data.readString()
            self.scenes.append(SWFScene(sceneOffset, sceneName))

        frameLabelCount = data.readEncodedU32()
        for _ in range(frameLabelCount):
            frameNumber = data.readEncodedU32()
            frameLabel = data.readString()
            self.frameLabels.append(SWFFrameLabel(frameNumber, frameLabel))
+
class TagDefineBinaryData(DefinitionTag):
    """
    DefineBinaryData (tag type 87): embeds an arbitrary binary blob under a
    standard 16-bit character ID, entered into the character dictionary.
    Typically paired with a SymbolClass tag that binds the blob to an AS3
    ByteArray subclass, which is populated with this data on instantiation.
    """
    TYPE = 87

    def __init__(self):
        super(TagDefineBinaryData, self).__init__()

    @property
    def name(self):
        return "DefineBinaryData"

    @property
    def type(self):
        return TagDefineBinaryData.TYPE

    def parse(self, data, length, version=1):
        self.characterId = data.readUI16()
        self.reserved = data.readUI32()
        # Payload = tag length minus the 6 header bytes read above.
        self.data = data.read(length - 6)
+
class TagDefineFontName(Tag):
    """
    DefineFontName (tag type 88): display name and copyright string for a
    previously defined font.
    """
    TYPE = 88

    def __init__(self):
        super(TagDefineFontName, self).__init__()

    @property
    def name(self):
        return "DefineFontName"

    @property
    def type(self):
        return TagDefineFontName.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 9

    def get_dependencies(self):
        deps = super(TagDefineFontName, self).get_dependencies()
        deps.add(self.fontId)
        return deps

    def parse(self, data, length, version=1):
        self.fontId = data.readUI16()
        self.fontName = data.readString()
        self.fontCopyright = data.readString()
+
class TagDefineSound(Tag):
    """
    DefineSound (tag type 14): an event sound — format/rate/size/channel
    header followed by the raw encoded sound data.
    """
    TYPE = 14

    def __init__(self):
        super(TagDefineSound, self).__init__()

    @property
    def name(self):
        return "TagDefineSound"

    @property
    def type(self):
        return TagDefineSound.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 1

    def parse(self, data, length, version=1):
        # Header is 7 bytes (UI16 + packed byte + UI32); anything shorter
        # cannot be a valid DefineSound tag.
        assert length > 7
        self.soundId = data.readUI16()
        self.soundFormat = data.readUB(4)
        self.soundRate = data.readUB(2)
        self.soundSampleSize = data.readUB(1)
        self.soundChannels = data.readUB(1)
        self.soundSamples = data.readUI32()
        # Remaining bytes are the encoded audio payload.
        self.soundData = BytesIO(data.read(length - 7))

    def __str__(self):
        s = super(TagDefineSound, self).__str__()
        s += " soundFormat: %s" % AudioCodec.tostring(self.soundFormat)
        s += " soundRate: %s" % AudioSampleRate.tostring(self.soundRate)
        s += " soundSampleSize: %s" % AudioSampleSize.tostring(self.soundSampleSize)
        s += " soundChannels: %s" % AudioChannels.tostring(self.soundChannels)
        return s
+
class TagStartSound(Tag):
    """Starts playback of a previously defined sound by id (tag 15)."""
    TYPE = 15
    def __init__(self):
        super(TagStartSound, self).__init__()

    @property
    def name(self):
        return "TagStartSound"

    @property
    def type(self):
        return TagStartSound.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 1

    def parse(self, data, length, version=1):
        """Read the target sound id and its SOUNDINFO playback parameters."""
        self.soundId = data.readUI16()
        self.soundInfo = data.readSOUNDINFO()
+
class TagStartSound2(Tag):
    """Starts playback of a sound referenced by class name instead of id
    (tag 89, available from SWF 9)."""
    TYPE = 89
    def __init__(self):
        super(TagStartSound2, self).__init__()

    @property
    def name(self):
        return "TagStartSound2"

    @property
    def type(self):
        return TagStartSound2.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 9

    def parse(self, data, length, version=1):
        """Read the sound class name and its SOUNDINFO playback parameters."""
        self.soundClassName = data.readString()
        self.soundInfo = data.readSOUNDINFO()
+
class TagSoundStreamHead(Tag):
    """Describes the format of a streaming sound that follows in
    SoundStreamBlock tags (tag 18)."""
    TYPE = 18
    def __init__(self):
        super(TagSoundStreamHead, self).__init__()

    @property
    def name(self):
        return "TagSoundStreamHead"

    @property
    def type(self):
        return TagSoundStreamHead.TYPE

    @property
    def level(self):
        return 1

    @property
    def version(self):
        return 1

    def parse(self, data, length, version=1):
        """Parse playback and stream format bytes; MP3 streams carry an
        extra 16-bit latency-seek field (bit order must not change)."""
        # byte 1: playback hints - reserved(4) | rate(2) | sample size(1) | channels(1)
        self.reserved0 = data.readUB(4)
        self.playbackRate = data.readUB(2)
        self.playbackSampleSize = data.readUB(1)
        self.playbackChannels = data.readUB(1)

        # byte 2: actual stream format - format(4) | rate(2) | sample size(1) | channels(1)
        self.soundFormat = data.readUB(4)
        self.soundRate = data.readUB(2)
        self.soundSampleSize = data.readUB(1)
        self.soundChannels = data.readUB(1)

        self.samples = data.readUI16()
        # Only MP3 streams include the latency-seek value.
        self.latencySeek = data.readSI16() if self.soundFormat == AudioCodec.MP3 else None
        hdr = 6 if self.soundFormat == AudioCodec.MP3 else 4
        # Sanity check: we must have consumed exactly the tag body.
        assert hdr == length

    def __str__(self):
        """Append decoded playback and stream format fields."""
        s = super(TagSoundStreamHead, self).__str__()
        s += " playbackRate: %s" % AudioSampleRate.tostring(self.playbackRate)
        s += " playbackSampleSize: %s" % AudioSampleSize.tostring(self.playbackSampleSize)
        s += " playbackChannels: %s" % AudioChannels.tostring(self.playbackChannels)
        s += " soundFormat: %s" % AudioCodec.tostring(self.soundFormat)
        s += " soundRate: %s" % AudioSampleRate.tostring(self.soundRate)
        s += " soundSampleSize: %s" % AudioSampleSize.tostring(self.soundSampleSize)
        s += " soundChannels: %s" % AudioChannels.tostring(self.soundChannels)
        return s
+
class TagSoundStreamHead2(TagSoundStreamHead):
    """
    The SoundStreamHead2 tag is identical to the SoundStreamHead tag, except it allows
    different values for StreamSoundCompression and StreamSoundSize (SWF 3 file format).
    """
    TYPE = 45

    def __init__(self):
        super(TagSoundStreamHead2, self).__init__()

    @property
    def name(self):
        return "TagSoundStreamHead2"

    @property
    def type(self):
        return TagSoundStreamHead2.TYPE
+
class TagSoundStreamBlock(Tag):
    """
    Carries one block of streaming sound data whose format is defined by the
    preceding SoundStreamHead/SoundStreamHead2 tag (tag 19).
    (The original docstring here was copy-pasted from SoundStreamHead2.)
    """
    TYPE = 19

    def __init__(self):
        super(TagSoundStreamBlock, self).__init__()

    @property
    def name(self):
        return "TagSoundStreamBlock"

    @property
    def type(self):
        return TagSoundStreamBlock.TYPE

    def parse(self, data, length, version=1):
        # unfortunately we can't see our associated SoundStreamHead from here,
        # so just stash the data
        self.data = BytesIO(data.read(length))

    def complete_parse_with_header(self, head):
        """Re-parse the stashed bytes once the stream head is known.
        Uncompressed formats need no further decoding; MP3 blocks start with
        a sample count and seek offset before the MPEG frames."""
        stream = SWFStream(self.data)
        if head.soundFormat in (AudioCodec.UncompressedNativeEndian,
                                AudioCodec.UncompressedLittleEndian):
            pass # data is enough
        elif head.soundFormat == AudioCodec.MP3:
            self.sampleCount = stream.readUI16()
            self.seekSize = stream.readSI16()
            self.mpegFrames = stream.read()
+
class TagDefineBinaryData(DefinitionTag):
    """
    The DefineBinaryData tag permits arbitrary binary data to be embedded in a SWF file.
    DefineBinaryData is a definition tag, like DefineShape and DefineSprite. It associates a blob
    of binary data with a standard SWF 16-bit character ID. The character ID is entered into the
    SWF file's character dictionary.

    NOTE(review): this re-defines a TagDefineBinaryData class declared earlier
    in this module with the same TYPE (87); this later definition wins at
    import time -- confirm the duplicate is intentional.
    """
    TYPE = 87

    def __init__(self):
        super(TagDefineBinaryData, self).__init__()

    @property
    def name(self):
        return "TagDefineBinaryData"

    @property
    def type(self):
        return TagDefineBinaryData.TYPE

    def parse(self, data, length, version=1):
        """Read character id (2 bytes), reserved word (4 bytes), then the
        remaining payload bytes."""
        assert length >= 6
        self.characterId = data.readUI16()
        self.reserved = data.readUI32()
        self.data = data.read(length - 4 - 2)
+
class TagProductInfo(Tag):
    """
    Undocumented in SWF10.  Records which product/toolchain produced the SWF
    (product and edition codes, version, build number and compile time).
    """
    TYPE = 41

    def __init__(self):
        super(TagProductInfo, self).__init__()

    @property
    def name(self):
        return "TagProductInfo"

    @property
    def type(self):
        return TagProductInfo.TYPE

    def parse(self, data, length, version=1):
        """Read product/edition codes, version numbers and timestamps."""
        self.product = data.readUI32()
        self.edition = data.readUI32()
        self.majorVersion, self.minorVersion = data.readUI8(), data.readUI8()
        self.build = data.readUI64()
        self.compileTime = data.readUI64()

    def __str__(self):
        """Append decoded product information."""
        s = super(TagProductInfo, self).__str__()
        s += " product: %s" % ProductKind.tostring(self.product)
        s += " edition: %s" % ProductEdition.tostring(self.edition)
        s += " major.minor.build: %d.%d.%d" % (self.majorVersion, self.minorVersion, self.build)
        s += " compileTime: %d" % (self.compileTime)
        return s
+
class TagScriptLimits(Tag):
    """
    The ScriptLimits tag includes two fields that can be used to override the default settings for
    maximum recursion depth and ActionScript time-out: MaxRecursionDepth and
    ScriptTimeoutSeconds.
    """
    TYPE = 65

    def __init__(self):
        super(TagScriptLimits, self).__init__()

    @property
    def name(self):
        return "TagScriptLimits"

    @property
    def type(self):
        return TagScriptLimits.TYPE

    def parse(self, data, length, version=1):
        """Read the two 16-bit limit overrides."""
        self.maxRecursionDepth = data.readUI16()
        self.scriptTimeoutSeconds = data.readUI16()

    def __str__(self):
        """Append the parsed limits."""
        s = super(TagScriptLimits, self).__str__()
        s += " maxRecursionDepth: %s" % self.maxRecursionDepth
        s += " scriptTimeoutSeconds: %s" % self.scriptTimeoutSeconds
        return s
+
class TagDebugID(Tag):
    """
    Undocumented in SWF10.  Some kind of GUID.
    """
    TYPE = 63

    def __init__(self):
        super(TagDebugID, self).__init__()

    @property
    def name(self):
        return "TagDebugID"

    @property
    def type(self):
        return TagDebugID.TYPE

    def parse(self, data, length, version=1):
        # The tag body is a raw 16-byte GUID.
        self.guid = data.read(16)
+
class TagExportAssets(Tag):
    """
    The ExportAssets tag makes portions of a SWF file available for import by other SWF files
    """
    TYPE = 56

    def __init__(self):
        super(TagExportAssets, self).__init__()

    @property
    def name(self):
        return "TagExportAssets"

    @property
    def version(self):
        return 5

    @property
    def type(self):
        return TagExportAssets.TYPE

    def parse(self, data, length, version=1):
        """Read the export count followed by that many EXPORT records."""
        self.count = data.readUI16()
        self.exports = [data.readEXPORT() for i in range(self.count)]

    def __str__(self):
        """Append the list of exported assets."""
        s = super(TagExportAssets, self).__str__()
        s += " exports: %s" % self.exports
        return s
+
class TagProtect(Tag):
    """
    The Protect tag marks a file as not importable for editing in an authoring environment. If the
    Protect tag contains no data (tag length = 0), the SWF file cannot be imported. If this tag is
    present in the file, any authoring tool should prevent the file from loading for editing.
    """
    TYPE = 24

    def __init__(self):
        super(TagProtect, self).__init__()
        # MD5-hashed password string, or None when the tag body is empty.
        self.password = None

    @property
    def name(self):
        return "TagProtect"

    @property
    def version(self):
        # Password-protected variant requires SWF 5; bare Protect is SWF 2.
        return 2 if self.password is None else 5

    @property
    def type(self):
        return TagProtect.TYPE

    def parse(self, data, length, version=1):
        """Read the optional password; an empty body means no password."""
        if length:
            self.password = data.readString()
        else:
            self.password = None

    def __str__(self):
        """Append the (possibly None) password."""
        s = super(TagProtect, self).__str__()
        s += " password: %r" % self.password
        return s
+
class TagEnableDebugger(Tag):
    """
    The EnableDebugger tag enables debugging. The password in the EnableDebugger tag is
    encrypted by using the MD5 algorithm, in the same way as the Protect tag.
    """
    TYPE = 58

    def __init__(self):
        super(TagEnableDebugger, self).__init__()

    @property
    def name(self):
        return "TagEnableDebugger"

    @property
    def version(self):
        return 5

    @property
    def type(self):
        return TagEnableDebugger.TYPE

    def parse(self, data, length, version=1):
        # The tag body is just the MD5-encrypted password string.
        self.password = data.readString()

    def __str__(self):
        """Append the parsed password."""
        s = super(TagEnableDebugger, self).__str__()
        s += " password: %r" % self.password
        return s
+
class TagEnableDebugger2(Tag):
    """
    The EnableDebugger2 tag enables debugging. The Password field is encrypted by using the
    MD5 algorithm, in the same way as the Protect tag.
    """
    TYPE = 64

    def __init__(self):
        super(TagEnableDebugger2, self).__init__()

    @property
    def name(self):
        return "TagEnableDebugger2"

    @property
    def version(self):
        return 6

    @property
    def type(self):
        return TagEnableDebugger2.TYPE

    def parse(self, data, length, version=1):
        # A reserved 16-bit field precedes the MD5-encrypted password.
        self.reserved0 = data.readUI16()
        self.password = data.readString()

    def __str__(self):
        """Append the parsed password."""
        s = super(TagEnableDebugger2, self).__str__()
        s += " password: %r" % self.password
        return s
+
class TagDoInitAction(Tag):
    """
    The DoInitAction tag is similar to the DoAction tag: it defines a series of bytecodes to be
    executed. However, the actions defined with DoInitAction are executed earlier than the usual
    DoAction actions, and are executed only once.
    """
    TYPE = 59

    def __init__(self):
        super(TagDoInitAction, self).__init__()

    @property
    def name(self):
        return "TagDoInitAction"

    @property
    def version(self):
        return 6

    @property
    def type(self):
        return TagDoInitAction.TYPE

    def get_dependencies(self):
        """Add the target sprite id to the inherited dependency set."""
        s = super(TagDoInitAction, self).get_dependencies()
        s.add(self.spriteId)
        return s

    def parse(self, data, length, version=1):
        """Read the sprite id these init actions apply to, then the actions."""
        self.spriteId = data.readUI16()
        self.actions = data.readACTIONRECORDs()
+
class TagDefineEditText(DefinitionTag):
    """
    The DefineEditText tag defines a dynamic text object, or text field.

    A text field is associated with an ActionScript variable name where the contents of the text
    field are stored. The SWF file can read and write the contents of the variable, which is always
    kept in sync with the text being displayed. If the ReadOnly flag is not set, users may change
    the value of a text field interactively
    """
    TYPE = 37

    def __init__(self):
        super(TagDefineEditText, self).__init__()

    @property
    def name(self):
        return "TagDefineEditText"

    @property
    def type(self):
        return TagDefineEditText.TYPE

    def get_dependencies(self):
        """Add the referenced font id (when a font is used) to the
        inherited dependency set.

        Fixed idiom: the original used the expression-statement
        `s.add(self.fontId) if self.hasFont else None`, which abuses a
        conditional expression for control flow; a plain `if` is clearer
        and behaves identically.
        """
        s = super(TagDefineEditText, self).get_dependencies()
        if self.hasFont:
            s.add(self.fontId)
        return s

    def parse(self, data, length, version=1):
        """Parse the text field definition.

        The flag bits and the conditional value fields below must be read
        in exactly this order -- each read advances the SWF bit/byte stream.
        """
        self.characterId = data.readUI16()
        self.bounds = data.readRECT()

        # flags (16 bits, read one at a time)
        self.hasText = data.readUB(1) == 1
        self.wordWrap = data.readUB(1) == 1
        self.multiline = data.readUB(1) == 1
        self.password = data.readUB(1) == 1

        self.readOnly = data.readUB(1) == 1
        self.hasTextColor = data.readUB(1) == 1
        self.hasMaxLength = data.readUB(1) == 1
        self.hasFont = data.readUB(1) == 1

        self.hasFontClass = data.readUB(1) == 1
        self.autoSize = data.readUB(1) == 1
        self.hasLayout = data.readUB(1) == 1
        self.noSelect = data.readUB(1) == 1

        self.border = data.readUB(1) == 1
        self.wasStatic = data.readUB(1) == 1
        self.html = data.readUB(1) == 1
        self.useOutlines = data.readUB(1) == 1

        # values (each present only when its flag above was set)
        self.fontId = data.readUI16() if self.hasFont else None
        self.fontClass = data.readString() if self.hasFontClass else None
        self.fontHeight = data.readUI16() if self.hasFont else None
        self.textColor = data.readRGBA() if self.hasTextColor else None
        self.maxLength = data.readUI16() if self.hasMaxLength else None

        self.align = data.readUI8() if self.hasLayout else None
        self.leftMargin = data.readUI16() if self.hasLayout else None
        self.rightMargin = data.readUI16() if self.hasLayout else None
        self.indent = data.readUI16() if self.hasLayout else None
        self.leading = data.readUI16() if self.hasLayout else None

        # backend info
        self.variableName = data.readString()
        self.initialText = data.readString() if self.hasText else None
+
class TagDefineButton(DefinitionTag):
    """
    The DefineButton tag defines a button character for later use by control tags such as
    PlaceObject.
    """
    TYPE = 7

    def __init__(self):
        super(TagDefineButton, self).__init__()

    @property
    def name(self):
        return "TagDefineButton"

    @property
    def type(self):
        return TagDefineButton.TYPE

    def get_dependencies(self):
        """Union the dependencies of every button state record."""
        s = super(TagDefineButton, self).get_dependencies()
        for b in self.buttonCharacters:
            s.update(b.get_dependencies())
        return s

    def parse(self, data, length, version=1):
        """Read the button id, its state records and the attached actions."""
        self.characterId = data.readUI16()
        self.buttonCharacters = data.readBUTTONRECORDs(version = 1)
        self.buttonActions = data.readACTIONRECORDs()
+
class TagDefineButton2(DefinitionTag):
    """
    DefineButton2 extends the capabilities of DefineButton by allowing any state transition to
    trigger actions.
    """
    TYPE = 34

    def __init__(self):
        super(TagDefineButton2, self).__init__()

    @property
    def name(self):
        return "TagDefineButton2"

    @property
    def type(self):
        return TagDefineButton2.TYPE

    def get_dependencies(self):
        """Union the dependencies of every button state record."""
        s = super(TagDefineButton2, self).get_dependencies()
        for b in self.buttonCharacters:
            s.update(b.get_dependencies())
        return s

    def parse(self, data, length, version=1):
        """Read the button header and records; the action block lives at a
        byte offset relative to the position of ActionOffset itself."""
        self.characterId = data.readUI16()
        self.reservedFlags = data.readUB(7)
        self.trackAsMenu = data.readUB(1) == 1
        # Remember where ActionOffset starts: it is the base of the offset.
        offs = data.tell()
        self.actionOffset = data.readUI16()
        self.buttonCharacters = data.readBUTTONRECORDs(version = 2)

        if self.actionOffset:
            # if we have actions, seek to the first one
            data.seek(offs + self.actionOffset)
            self.buttonActions = data.readBUTTONCONDACTIONSs()
+
class TagDefineButtonSound(Tag):
    """
    The DefineButtonSound tag defines which sounds (if any) are played on state transitions.
    """
    TYPE = 17

    def __init__(self):
        super(TagDefineButtonSound, self).__init__()

    @property
    def name(self):
        return "TagDefineButtonSound"

    @property
    def type(self):
        return TagDefineButtonSound.TYPE

    @property
    def version(self):
        return 2

    def parse(self, data, length, version=1):
        """Read the button id, then one (soundId, SOUNDINFO) pair per state
        transition; a zero soundId means no sound and no SOUNDINFO follows."""
        self.buttonId = data.readUI16()

        for event in 'OverUpToIdle IdleToOverUp OverUpToOverDown OverDownToOverUp'.split():
            soundId = data.readUI16()
            setattr(self, 'soundOn' + event, soundId)
            soundInfo = data.readSOUNDINFO() if soundId else None
            setattr(self, 'soundInfoOn' + event, soundInfo)
+
class TagDefineScalingGrid(Tag):
    """
    The DefineScalingGrid tag introduces the concept of 9-slice scaling, which allows
    component-style scaling to be applied to a sprite or button character.
    """
    TYPE = 78

    def __init__(self):
        super(TagDefineScalingGrid, self).__init__()

    @property
    def name(self):
        return "TagDefineScalingGrid"

    @property
    def type(self):
        return TagDefineScalingGrid.TYPE

    def parse(self, data, length, version=1):
        """Read the target character id and the central scaling rectangle."""
        self.characterId = data.readUI16()
        self.splitter = data.readRECT()
+
class TagDefineVideoStream(DefinitionTag):
    """
    DefineVideoStream defines a video character that can later be placed on the display list.
    """
    TYPE = 60

    def __init__(self):
        super(TagDefineVideoStream, self).__init__()

    @property
    def name(self):
        return "TagDefineVideoStream"

    @property
    def type(self):
        return TagDefineVideoStream.TYPE

    def parse(self, data, length, version=1):
        """Read the stream header: id, frame count, dimensions, then a
        packed byte (reserved(4) | deblocking(3) | smoothing(1)) and codec."""
        self.characterId = data.readUI16()
        self.numFrames = data.readUI16()
        self.width = data.readUI16()
        self.height = data.readUI16()
        reserved0 = data.readUB(4)
        self.videoDeblocking = data.readUB(3)
        self.videoSmoothing = data.readUB(1)
        self.codec = data.readUI8()
+
class TagVideoFrame(Tag):
    """
    VideoFrame provides a single frame of video data for a video character that is already defined
    with DefineVideoStream.
    """
    TYPE = 61

    def __init__(self):
        super(TagVideoFrame, self).__init__()

    @property
    def name(self):
        return "TagVideoFrame"

    @property
    def type(self):
        return TagVideoFrame.TYPE

    def parse(self, data, length, version=1):
        """Read the stream id, frame number, then the raw codec payload
        (tag length minus the 4 header bytes)."""
        self.streamId = data.readUI16()
        self.frameNumber = data.readUI16()
        self.videoData = data.read(length - 4)
+
class TagDefineMorphShape2(TagDefineMorphShape):
    """
    The DefineMorphShape2 tag extends the capabilities of DefineMorphShape by using a new
    morph line style record in the morph shape. MORPHLINESTYLE2 allows the use of new
    types of joins and caps as well as scaling options and the ability to fill the strokes of the morph
    shape.
    """
    TYPE = 84

    @property
    def name(self):
        return "TagDefineMorphShape2"

    @property
    def type(self):
        return TagDefineMorphShape2.TYPE

    @property
    def version(self):
        return 8

    def get_dependencies(self):
        """Union the dependencies of both the start and end edge shapes."""
        s = super(TagDefineMorphShape2, self).get_dependencies()
        s.update(self.startEdges.get_dependencies())
        s.update(self.endEdges.get_dependencies())
        return s

    def parse(self, data, length, version=1):
        """Parse bounds, stroke flags, style arrays and both edge shapes.
        Read order matters: every call advances the SWF stream."""
        self._morphFillStyles = []
        self._morphLineStyles = []
        self.characterId = data.readUI16()

        self.startBounds = data.readRECT()
        self.endBounds = data.readRECT()
        self.startEdgeBounds = data.readRECT()
        self.endEdgeBounds = data.readRECT()

        self.reserved0 = data.readUB(6)
        self.usesNonScalingStrokes = data.readUB(1) == 1
        self.usesScalingStrokes = data.readUB(1) == 1

        # Offset to EndEdges; the value is unused here but the read is
        # required to advance past the field.
        offset = data.readUI32()
        self._morphFillStyles = data.readMORPHFILLSTYLEARRAY()
        self._morphLineStyles = data.readMORPHLINESTYLEARRAY(version = 2)

        self.startEdges = data.readSHAPE();
        self.endEdges = data.readSHAPE();
+
if __name__ == '__main__':
    # Sanity checks over the tag table: every creatable tag must report the
    # type it was registered under.
    for x in range(256):
        y = TagFactory.create(x)
        if y:
            assert y.type == x, y.name + ' is misnamed'

    # Fixed: iterating globals().items() directly raises RuntimeError
    # ("dictionary changed size during iteration") because binding k/v/y
    # inserts new names into globals(); snapshot the items first.
    for k, v in list(globals().items()):
        if k.startswith('Tag') and hasattr(v, 'TYPE'):
            y = TagFactory.create(v.TYPE)
            if y is None:  # was `== None`; identity test is the idiom
                # print(v.__name__, 'missing', 'for', v.TYPE)
                pass

+ 55 - 0
format_convert/swf/utils.py

@@ -0,0 +1,55 @@
+from .consts import BitmapType
+import math
+
class NumberUtils(object):
    """Rounding helpers for pixel coordinates."""

    @classmethod
    def round_pixels_20(cls, pixels):
        """Round *pixels* to 2 decimal places."""
        scaled = round(pixels * 100)
        return scaled / 100

    @classmethod
    def round_pixels_400(cls, pixels):
        """Round *pixels* to 4 decimal places."""
        scaled = round(pixels * 10000)
        return scaled / 10000
+ 
class ColorUtils(object):
    """Helpers for unpacking 32-bit ARGB color values."""

    @classmethod
    def alpha(cls, color):
        """Return the alpha component normalized to [0.0, 1.0]."""
        # The top byte of the 32-bit value is alpha.
        return int(color >> 24) / 255.0

    @classmethod
    def rgb(cls, color):
        """Return the low 24 bits holding the RGB triple."""
        return color & 0xFFFFFF

    @classmethod
    def to_rgb_string(cls, color):
        """Format as a zero-padded lowercase hex string, e.g. '#0000ff'."""
        return "#%06x" % color
+        
class ImageUtils(object):
    """Sniffs image container formats from a seekable binary stream."""

    @classmethod
    def get_image_size(cls, data):
        """Not implemented; placeholder kept for interface compatibility."""
        pass

    @classmethod
    def get_image_type(cls, data):
        """Return the BitmapType constant matching the stream's magic bytes.

        Recognizes JPEG, PNG and GIF89a signatures.  Returns 0 when the
        stream holds fewer than 9 bytes or the signature is unknown.  The
        stream position is restored before returning.
        """
        pos = data.tell()
        image_type = 0
        data.seek(0, 2)  # moves file pointer to final position
        if data.tell() > 8:
            data.seek(0)
            b0 = ord(data.read(1))
            b1 = ord(data.read(1))
            b2 = ord(data.read(1))
            b3 = ord(data.read(1))
            b4 = ord(data.read(1))
            b5 = ord(data.read(1))
            b6 = ord(data.read(1))
            b7 = ord(data.read(1))
            # JPEG: 0xFF then an SOI (0xD8) or EOI (0xD9) marker byte.
            # Fixed: the second comparison was the literal `1 == 0xd9`
            # (always False), so only 0xFF 0xD8 streams were ever matched
            # and the intended 0xD9 alternative was dead.
            if b0 == 0xff and (b1 == 0xd8 or b1 == 0xd9):
                image_type = BitmapType.JPEG
            # PNG: \x89 P N G \r \n \x1a \n
            elif b0 == 0x89 and b1 == 0x50 and b2 == 0x4e and b3 == 0x47 and \
                b4 == 0x0d and b5 == 0x0a and b6 == 0x1a and b7 == 0x0a:
                image_type = BitmapType.PNG
            # GIF89a signature: "GIF89a"
            elif b0 == 0x47 and b1 == 0x49 and b2 == 0x46 and b3 == 0x38 and b4 == 0x39 and b5 == 0x61:
                image_type = BitmapType.GIF89A
        data.seek(pos)
        return image_type

+ 278 - 0
format_convert/table_correct.py

@@ -0,0 +1,278 @@
+import math
+import cmath
+import traceback
+
+import numpy as np
+import cv2
+
+
+# 图片旋转
+from format_convert.judge_platform import get_platform
+
+
def rotate_bound(image, angle):
    """Rotate *image* by *angle* degrees around its center, widening the
    canvas so the rotated content is not clipped horizontally.

    The output height is kept equal to the input height (the symmetric
    nH formula is intentionally commented out).

    NOTE(review): a second, try/except-free `rotate_bound` defined later in
    this module shadows this one at import time.  Also note the except path
    only prints and implicitly returns None.
    """
    try:
        # Image dimensions and rotation center.
        (h, w) = image.shape[:2]
        (cX, cY) = (w // 2, h // 2)

        # Rotation matrix; |cos|/|sin| are used to size the new canvas.
        M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0)
        cos = np.abs(M[0, 0])
        sin = np.abs(M[0, 1])

        # New bounding width; height deliberately kept as the original.
        nW = int((h * sin) + (w * cos))
        #     nH = int((h * cos) + (w * sin))
        nH = h

        # Shift the transform so the rotated image is centered.
        M[0, 2] += (nW / 2) - cX
        M[1, 2] += (nH / 2) - cY

        return cv2.warpAffine(image, M, (nW, nH),flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
    except Exception as e:
        print("rotate_bound", e)
+
+
# Get the page contour and its rotation angle from a scanned image.
def get_minAreaRect(image):
    """Locate a page-sized contour and derive its skew angle.

    Draws the first contour onto *image* (side effect kept from the
    original) and returns ``(image, box, angle)`` where *box* are the
    int corner points of the last contour whose min-area rectangle falls
    within the expected page size, and *angle* is its deviation from
    horizontal, normalized to at most 45 degrees.

    Returns ``([-1], [-1], [-1])`` on error or -- fixed here -- when no
    contour passes the size filter (previously `box`/`angle` were never
    assigned in that case and the function raised UnboundLocalError,
    which the bare except then masked).
    """
    try:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        ret, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        if get_platform() == "Windows":
            cv2.imshow("binary", binary)
            cv2.waitKey(0)
        contours, hier = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        # Draw the outline of the first contour (debug visualization).
        cv2.drawContours(image, contours, 0, (255, 0, 255), 3)
        if get_platform() == "Windows":
            cv2.imshow("image", image)
            cv2.waitKey(0)

        box = None
        angle = None
        for c in contours:  # iterate candidate contours
            rect = cv2.minAreaRect(c)  # minimal enclosing rotated rectangle
            box_ = cv2.boxPoints(rect)
            h = abs(box_[3, 1] - box_[1, 1])
            w = abs(box_[3, 0] - box_[1, 0])
            print("宽,高", w, h)
            # Keep only contours within the expected page dimensions.
            if h > 3000 or w > 2200:
                continue
            if h < 2500 or w < 1500:
                continue
            # Integer corner coordinates of the accepted rectangle.
            box = np.int0(cv2.boxPoints(rect))
            # Deviation from horizontal, folded into [-45, 45].
            angle = rect[2]
            if abs(angle) > 45:
                angle = 90 - abs(angle)

        print("轮廓数量", len(contours))
        if box is None:
            # No page-sized contour found: report failure like the except path.
            return [-1], [-1], [-1]
        return image, box, angle

    except Exception as e:
        print("get_minAreaRect", traceback.print_exc())
        return [-1], [-1], [-1]
+
+
def get_table_line(image):
    """Binarize *image* and return its edge response, highlighting table
    ruling lines.

    Applies an inverted-grayscale adaptive threshold, then an
    8-neighbour Laplacian filter.  Returns the filtered binary image.

    Cleanup: the original also computed an identical `binary_col`
    response (same kernel) that was never used; it and the large block
    of commented-out morphology experiments have been removed.
    """
    # Adaptive threshold on the inverted grayscale picks out dark strokes.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    binary = cv2.adaptiveThreshold(~gray, 255,
                                   cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 203, -10)

    # 8-neighbour Laplacian edge kernel.  (A Sobel pair was tried before:
    # kernel = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], np.float32))
    kernel_row = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], np.float32)
    binary_row = cv2.filter2D(binary, -1, kernel=kernel_row)

    return binary_row
+
+
def detect_line(img):
    """Estimate the average skew angle (degrees) of near-horizontal lines
    detected in *img* via Canny edges + Hough transform.

    Returns 0 when no usable lines are found or on any error.  Fixes two
    defects: the original divided by ``len(angle_list)`` even when every
    candidate line had been filtered out (ZeroDivisionError), and its
    except path implicitly returned None, which crashed callers that
    compare the result against 0.
    """
    try:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 100, 1000)

        lines = cv2.HoughLines(edges, 1, np.pi/180, 200)
        if lines is None or len(lines) == 0:
            return 0

        angle_list = []
        for i in range(len(lines)):
            # Each line is (rho, theta) in Hesse normal form.
            rho = lines[i][0][0]
            theta = lines[i][0][1]
            a = np.cos(theta)
            b = np.sin(theta)
            x0 = a*rho
            y0 = b*rho
            # Two points far along the line, for slope computation.
            x1 = float(x0 + 1000*(-b))
            y1 = float(y0 + 1000*a)
            x2 = float(x0 - 1000*(-b))
            y2 = float(y0 - 1000*a)

            # NOTE(review): this also discards lines with negative dx/dy;
            # confirm abs() was not intended here.
            if x2-x1 == 0 or y2-y1 == 0 or x2-x1 <= 0.01 or y2-y1 <= 0.01:
                continue

            k = -(y2-y1) / (x2-x1)
            # Keep only near-horizontal lines (|slope| <= 5).
            if abs(k) <= 5:
                angle_list.append(np.arctan(k) * (180/math.pi))

        if not angle_list:
            # Every line was filtered out: no measurable skew.
            return 0

        angle_avg = sum(angle_list) / len(angle_list)
        print("angle_avg", angle_avg)
        return angle_avg
    except Exception as e:
        print("detect_line", e)
        return 0
+
+
def get_rotated_image_old(image, output_path):
    """Deskew *image* using the Hough-line angle estimate and write the
    rotated result to *output_path*.

    Returns True when a rotated image was written, False when no rotation
    was needed or on error.  Fixed: the original tested ``angle != 0``,
    which is True when detect_line fails and returns None, so rotate_bound
    was then called with a None angle; a plain truthiness test rejects
    both 0 and None.
    """
    try:
        angle = detect_line(image)
        if angle:
            rotated = rotate_bound(image, angle)
            print("[INFO] angle: {:.3f}".format(angle))
            cv2.imwrite(output_path, rotated)
            return True
        else:
            print("angle", angle)
            return False
    except Exception as e:
        print("get_rotated_image", e)
        return False
+
+
def rotate_bound(image, angle):
    """Rotate *image* by *angle* degrees around its center, widening the
    canvas so rotated content is not clipped horizontally; output height
    stays equal to the input height.

    NOTE(review): this redefines the earlier `rotate_bound` (which wrapped
    the same logic in try/except); this later, unguarded version is the one
    callers actually get.
    """
    # Image dimensions and rotation center.
    (h, w) = image.shape[:2]
    (cX, cY) = (w // 2, h // 2)

    # Rotation matrix; |cos|/|sin| size the new canvas.
    M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0)
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])

    # New bounding width; height deliberately kept as the original.
    nW = int((h * sin) + (w * cos))
    #     nH = int((h * cos) + (w * sin))
    nH = h

    # Shift the transform so the rotated image stays centered.
    M[0, 2] += (nW / 2) - cX
    M[1, 2] += (nH / 2) - cY

    return cv2.warpAffine(image, M, (nW, nH),flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
+
+
# Get the image rotation angle.
def get_minAreaRect(image):
    """Return ``cv2.minAreaRect`` over the coordinates of all non-background
    pixels of the Otsu-thresholded, inverted grayscale image.

    NOTE(review): this redefinition shadows the earlier get_minAreaRect that
    returned an (image, box, angle) triple; `get_rotated_image` below still
    unpacks three values and will therefore always fall into its except
    path -- confirm which variant is intended.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.bitwise_not(gray)
    thresh = cv2.threshold(gray, 0, 255,
                           cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
    # Row/col indices of every foreground pixel.
    coords = np.column_stack(np.where(thresh > 0))
    return cv2.minAreaRect(coords)
+
+
def get_rotated_image(image, output_path):
    """Deskew *image* using the min-area-rect angle and write the result to
    *output_path*.  Returns True on success, ``[-1]`` on failure.

    NOTE(review): with the later get_minAreaRect definition in effect the
    three-way unpack below raises (it returns a single cv2 rect tuple), so
    control always reaches the except path.  Also, `if not angle` writes the
    file only when angle is 0/falsy -- this looks inverted; confirm the
    intended condition before changing it.
    """
    try:
        image_temp, box, angle = get_minAreaRect(image)
        if angle == [-1]:
            return [-1]

        # angle = get_minAreaRect(image)[-1]
        # Treat large angles as noise: do not rotate.
        if abs(angle) >= 15:
            angle = 0
        rotated = rotate_bound(image, angle)

        # cv2.putText(rotated, "angle: {:.2f} ".format(angle),
        #             (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

        # show the output image
        # print("[INFO] angle: {:.3f}".format(angle))
        if not angle:
            cv2.imwrite(output_path, rotated)
        # cv2.imshow("input", image)
        # cv2.waitKey(0)
        # cv2.imshow("output", rotated)
        # cv2.waitKey(0)
        return True
    except Exception as e:
        print("get_rotated_image", e)
        return [-1]
+
+
if __name__ == '__main__':
    # Ad-hoc manual check: deskew a sample image in place.
    temp_path = "temp/complex/8.png"
    # temp_path = "temp/92d885dcb77411eb914300163e0ae709/92d8ef5eb77411eb914300163e0ae709_pdf_page1.png"
    image = cv2.imread(temp_path)
    get_rotated_image(image, temp_path)

    # test()
BIN
format_convert/temp-0.5795441.jpg


BIN
format_convert/temp0.0.jpg


BIN
format_convert/temp0.jpg


+ 0 - 0
format_convert/temp1.1368683772161603e-13.jpg


+ 0 - 0
format_convert/temp107.52.jpg


BIN
format_convert/temp211.jpg


BIN
format_convert/temp232.61.jpg


+ 0 - 0
format_convert/temp31312.0.jpg


BIN
format_convert/temp316.63.jpg


BIN
format_convert/temp349.4635.jpg


+ 0 - 0
format_convert/temp350.0.jpg


BIN
format_convert/temp398.0.jpg


+ 0 - 0
format_convert/temp80.64.jpg


BIN
format_convert/temp90.0.jpg


+ 411 - 0
format_convert/test_ocr_interface.py

@@ -0,0 +1,411 @@
+import base64
+import copy
+import ctypes
+import gc
+import hashlib
+import inspect
+import multiprocessing
+import os
+import random
+import traceback
+from glob import glob, iglob
+import threading
+import time
+import urllib
+
+import psutil
+import requests
+import json
+import sys
+from multiprocessing import Process, Pool
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+from format_convert.convert import convert
+from ocr.ocr_interface import ocr, OcrModels
+from otr.otr_interface import otr, OtrModels
+from format_convert.judge_platform import get_platform
+
+
class myThread(threading.Thread):
    """Load-test worker: calls test_convert() forever, timing each pass."""

    def __init__(self, threadName):
        super(myThread, self).__init__()
        self.threadName = threadName

    def run(self):
        while True:
            t0 = time.time()
            test_convert()
            print(self.threadName, "finish!", time.time() - t0)
+
+
class myThread_appendix(threading.Thread):
    """Worker that converts one assigned slice of downloaded appendix dirs once."""

    def __init__(self, threadName, _list):
        super(myThread_appendix, self).__init__()
        self.threadName = threadName
        self._list = _list

    def run(self):
        t0 = time.time()
        test_appendix_downloaded(self._list)
        print(self.threadName, "finish!", time.time() - t0)
+
+
def test_ocr():
    """POST a sample page image (base64) to the standalone OCR service."""
    with open("test_files/开标记录表3_page_0.png", "rb") as fp:
        base64_data = base64.b64encode(fp.read())
    # print(base64_data)
    url = "".join([local_url, ":15011", '/ocr'])
    # url = 'http://127.0.0.1:15013/ocr'

    r = requests.post(url, data=base64_data, timeout=2000)
    # print("test:", r.content.decode("utf-8"))
+
+
def test_otr():
    """POST a sample page image (base64) to the standalone OTR (table) service."""
    with open("test_files/开标记录表3_page_0.png", "rb") as fp:
        base64_data = base64.b64encode(fp.read())
    # print(base64_data)
    url = "".join([local_url, ":15017", '/otr'])
    # url = 'http://127.0.0.1:15013/ocr'

    r = requests.post(url, data=base64_data, timeout=2000)
    # print("test:", r.content.decode("utf-8"))
+
+
def test_convert():
    """Send one local file to the /convert endpoint and print the result pages."""
    # path = "开标记录表3.pdf"
    # path = "test_files/开标记录表3_page_0.png"
    # path = "test_files/1.docx"
    # path = '光明食品(集团)有限公司2017年度经审计的合并及母公司财务报表.pdf'
    # path = '光明.pdf'
    # path = 'D:/BIDI_DOC/比地_文档/Oracle11g学生成绩管理系统.docx'
    # path = "C:\\Users\\Administrator\\Desktop\\1600825332753119.doc"
    # path = "temp/complex/8.png"
    # path = "合同备案.doc"
    # path = "1.png"
    # path = "1.pdf"
    # path = "(清单)衢州市第二人民医院二期工程电缆采购项目.xls"
    # path = "D:\\Project\\format_conversion\\appendix_test\\temp\\00fb3e52bc7e11eb836000163e0ae709" + \
    #     "\\00fb43acbc7e11eb836000163e0ae709.png"
    # path = "D:\\BIDI_DOC\\比地_文档\\8a949486788ccc6d017969f189301d41.pdf"
    # path = "be8a17f2cc1b11eba26800163e0857b6.docx"
    # path = "江苏省通州中等专业学校春节物资采购公 告.docx"
    # path = "test_files/1.zip"
    # path = "C:\\Users\\Administrator\\Desktop\\33f52292cdad11ebb58300163e0857b6.zip"
    path = "C:\\Users\\Administrator\\Desktop\\Test_Interface\\1623392355541.zip"
    with open(path, "rb") as f:
        base64_data = base64.b64encode(f.read())
    # print(base64_data)
    url = _url + '/convert'
    # url = 'http://127.0.0.1:15014/convert'
    # headers = {'Content-Type': 'application/json'}
    headers = {
        'Connection': 'keep-alive'
    }
    data = urllib.parse.urlencode({"file": base64_data, "type": path.split(".")[-1]}).encode('utf-8')
    req = urllib.request.Request(url, data=data, headers=headers)
    with urllib.request.urlopen(req) as response:
        # NOTE(review): eval() on a service response is unsafe on untrusted
        # data; the service apparently returns a Python-repr dict, not JSON.
        _dict = eval(response.read().decode("utf-8"))
    result = _dict.get("result")
    is_success = _dict.get("is_success")
    print("is_success", is_success)
    print("len(result)", len(result))
    for i in range(len(result)):
        print("=================")
        print(result[i])
        print("-----------------")
    # print(len(eval(r.content.decode("utf-8")).get("result")))
    # print(r.content)
+
+
def test_appendix_downloaded(_list):
    """Convert every file inside each already-downloaded appendix directory."""
    # use appendix files that were already downloaded
    i = 0
    # for docid_file in glob("/mnt/html_files/*"):
    for docid_file in _list:
        if i % 100 == 0:
            print("Loop", i)

        # print(docid_file)
        for file_path in iglob(docid_file + "/*"):
            print(file_path)
            with open(file_path, "rb") as f:
                base64_data = base64.b64encode(f.read())
            url = _url + '/convert'
            # print(url)

            try:
                # headers = {
                #     'Connection': 'keep-alive'
                # }
                # data = urllib.parse.urlencode({"file": base64_data, "type": file_path.split(".")[-1]}).encode('utf-8')
                # req = urllib.request.Request(url, data=data, headers=headers)
                # with urllib.request.urlopen(req, timeout=2000) as response:
                #     _dict = eval(response.read().decode("utf-8"))

                # timeout=2000
                r = requests.post(url, data={"file": base64_data,
                                             "type": file_path.split(".")[-1]}, timeout=2000)
                # NOTE(review): eval() on the service response — only safe
                # because the service is trusted/internal.
                _dict = eval(r.content.decode("utf-8"))
                print("is_success:", _dict.get("is_success"))
            except Exception as e:
                print("docid " + str(docid_file) + " time out!", e)
        i += 1
+
+
def test_convert_maxcompute():
    """Run convert() in-process (MaxCompute style) with freshly-loaded models.

    Loads OCR/OTR models, converts each path in path_list, prints gc
    diagnostics around each conversion, then releases the models.
    """
    try:
        ocr_model = OcrModels().get_model()
        otr_model = OtrModels().get_model()

        path_list = []
        path_suffix = "未命名4.pdf"
        if get_platform() == "Windows":
            path_prefix = "C:\\Users\\Administrator\\Desktop\\Test_ODPS\\"
            # path_prefix = "C:\\Users\\Administrator\\Desktop\\"
            path_list.append(path_prefix + path_suffix)

        else:
            path_list.append(path_suffix)

        result_list = []
        for path in path_list:
            with open(path, "rb") as f:
                base64_data = base64.b64encode(f.read())
                # print("------------")
                # print(base64_data)
                # print('------------')
            data = {"file": base64_data, "type": path.split(".")[-1]}

            result_dict = convert(data, ocr_model, otr_model)

            # gc diagnostics: look for leaks caused by a single conversion
            print("garbage object num:%d" % (len(gc.garbage)))
            _unreachable = gc.collect()

            print("unreachable object num:%d" % (_unreachable))
            print("garbage object num:%d" % (len(gc.garbage)))

            result_list.append(result_dict)

        for result_dict in result_list:
            result = result_dict.get("result_text")
            is_success = result_dict.get("is_success")
            for i in range(len(result)):
                print("=================", "is_success", is_success, i, "in", len(result))

            #     _dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
            #     _dir = os.path.abspath(_dir) + os.sep
            #     if i == 0:
            #         with open(_dir + "result.html", "w") as ff:
            #             ff.write(result[i])
            #     else:
            #         with open(_dir + "result.html", "a") as ff:
            #             ff.write("<div>=================================================</div>")
            #             ff.write(result[i])
            # print("write result to", _dir + "result.html")

        del otr_model
        del ocr_model
        gc.collect()
    except Exception as e:
        print(e)
        usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
        print("memory 2", str(usage))
+
+
def getMDFFromFile(path):
    """Stream *path* in 4 KiB chunks and return ``(md5_hexdigest, byte_length)``.

    On any error the traceback is printed and ``(None, partial_length)`` is
    returned instead of raising.
    """
    total = 0
    try:
        digest = hashlib.md5()
        with open(path, "rb") as fp:
            for chunk in iter(lambda: fp.read(4096), b""):
                total += len(chunk)
                digest.update(chunk)
        return digest.hexdigest(), total
    except Exception:
        traceback.print_exc()
        return None, total
+
+
def get_base64():
    """Print the base64 encoding, then the (md5, size), of a fixed sample file."""
    path = "C:\\Users\\Administrator\\Desktop\\Test_ODPS\\1623430252934.doc"
    with open(path, "rb") as fp:
        base64_data = base64.b64encode(fp.read())
    print("------------")
    print(base64_data)
    print('------------')
    print(getMDFFromFile(path))
+
+
def test_init_model():
    """Experiment: load the models inside a thread and force-stop it afterwards."""
    class MyThread(threading.Thread):
        # Loads both models at construction time; run() only produces a
        # random result so the experiment measures model-load behaviour.
        def __init__(self):
            super(MyThread, self).__init__()
            self.ocr_model = OcrModels().get_model()
            self.otr_model = OtrModels().get_model()

        def run(self):
            self.result = random.randint(1, 10)

        def get_result(self):
            return self.result

        def _async_raise(self, tid, exctype):
            """raises the exception, performs cleanup if needed"""
            # Injects an async exception into the thread identified by tid via
            # CPython's C API; res is the number of threads affected.
            tid = ctypes.c_long(tid)
            if not inspect.isclass(exctype):
                exctype = type(exctype)
            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
            if res == 0:
                raise ValueError("invalid thread id")
            elif res != 1:
                # More than one thread affected: undo by passing NULL exc.
                ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
                raise SystemError("PyThreadState_SetAsyncExc failed")

        def stop_thread(self, tid):
            self._async_raise(tid, SystemExit)

    class GetModel:
        def __init__(self):

            # usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            # print("memory 2", str(usage))
            return

        def process(self):
            # Start, wait for completion, then kill the thread if it survived.
            thread = MyThread()
            thread.start()
            thread.join()
            result = thread.get_result()
            print(result)
            if thread.is_alive():
                thread.stop_thread(thread.ident)
            # usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            # print("memory 3", str(usage))

    m = GetModel()
    m.process()
+
+
# spawn start method copies the process; otherwise the model hangs
# multiprocessing.set_start_method('spawn', force=True)
# Placeholders replaced with real model objects inside child_process_1.
ocr_model = ""
otr_model = ""
class TestProcess:
    """Batch driver: queue conversion jobs and run them in child processes.

    NOTE(review): child_process_1/child_process_2 run in *separate processes*,
    so their mutations of self.result_list / self.result_num and the module
    globals never propagate back to the parent — verify this is intentional.
    """
    def __init__(self):
        super(TestProcess, self).__init__()
        self.process_num = 2      # batch size before a run is triggered
        self.data_list = []
        self.result_list = []
        self.current_data = ""
        self.result_num = 0

    def child_process_1(self):
        # initialise the models (in this child process's module globals)
        globals().update({"ocr_model": OcrModels().get_model()})
        globals().update({"otr_model": OtrModels().get_model()})

        # convert each queued item, one grandchild process per item
        for data in self.data_list:
            self.current_data = data
            # self.child_process_2()
            p = Process(target=self.child_process_2)
            p.start()
            p.join()
            if p.is_alive():
                print("p.close")
                p.close()

        # reset
        self.data_list = []

        # drop the previous models
        global ocr_model, otr_model
        del ocr_model
        del otr_model
        gc.collect()

    def child_process_2(self):
        global ocr_model, otr_model
        result = convert(self.current_data, ocr_model, otr_model)
        print("result", result.get("is_success"))
        self.result_list.append(result)
        print("len(self.result_list)======================", len(self.result_list))
        self.result_num += 1

    def process(self, path_list):
        # Read and base64-encode each file, then queue it for conversion.
        for path in path_list:
            with open(path, "rb") as f:
                base64_data = base64.b64encode(f.read())
            data = {"file": base64_data, "type": path.split(".")[-1]}
            self.data_list.append(data)

        # run once process_num items have been queued
        if len(self.data_list) == self.process_num:

            p = Process(target=self.child_process_1)
            p.start()
            p.join()
            p.close()

            print("init data_list result_list!")
            self.data_list = []
            print("self.result_num", self.result_num)
+
+
def test_convert_process():
    """Drive TestProcess with two sample docs, then report parent-process RSS."""
    runner = TestProcess()
    runner.process(["1623430252934.doc", "1623430252934.doc"])

    usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
    print("----- memory info start - test_convert_process" + " - " + str(usage) + " GB")

    # t.process(["1.docx", "1.docx"])
    # usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
    # print("----- memory info start - test_convert_process" + " - " + str(usage) + " GB")
+
+
# Known deployment endpoints; _url selects which one the tests hit.
gpu_url = "http://192.168.2.101"
memory_url = "http://47.97.90.190"
local_url = "http://127.0.0.1"
production_url = "http://47.98.57.0"

_url = local_url + ":15015"

if __name__ == '__main__':
    # Manual entry point: uncomment exactly one scenario.
    # test_convert()
    # test_convert_process()
    test_convert_maxcompute()
    # test_init_model()
    # test_ocr()
    # test_otr()
    # test_appendix_downloaded()
    # get_base64()

    # print(getMDFFromFile("C:\\Users\\Administrator\\Desktop\\Test_ODPS\\1624900794475.docx"))

    # Multithreaded invocation #####################################
    # threads_num = 30
    # thread_list = []
    # glob_list = glob("html_files/*")
    # sub_num = int(len(glob_list) / threads_num)
    # print(len(glob_list), sub_num)
    #
    # for i in range(threads_num):
    #     if i == threads_num - 1:
    #         _list = glob_list[i*sub_num:]
    #     else:
    #         _list = glob_list[i*sub_num:(i+1)*sub_num]
    #     print(i*sub_num, len(_list))
    #
    #     thread = myThread_appendix("Thread-"+str(i), _list)
    #     thread_list.append(thread)
    #
    # for thread in thread_list:
    #     thread.start()
    # for thread in thread_list:
    #     thread.join()

+ 8 - 0
format_convert/test_walk.py

@@ -0,0 +1,8 @@
import os

# Collect every directory (with trailing separator) and file under the
# current directory, bottom-up, into a single flat list.
file_list = []
for root, dirs, files in os.walk("./", topdown=False):
    file_list.extend(os.path.join(root, d) + os.sep for d in dirs)
    file_list.extend(os.path.join(root, f) for f in files)
print(file_list)

+ 14 - 0
format_convert/testswf.py

@@ -0,0 +1,14 @@

from format_convert.swf.movie import SWF
from format_convert.swf.export import SVGExporter

# BUGFIX: both file handles were opened without ever being closed; use
# context managers so they are released even if parsing/export fails.
with open("ab.swf", 'rb') as swf_file:
    _swf = SWF(swf_file)
    svg_exporter = SVGExporter()

    # export!
    svg = _swf.export(svg_exporter)

# save the SVG
with open('svg', 'wb') as out_file:
    out_file.write(svg.read())

+ 198 - 0
format_convert/timeout_decorator.py

@@ -0,0 +1,198 @@
+"""
+Timeout decorator.
+
+    :copyright: (c) 2012-2013 by PN.
+    :license: MIT, see LICENSE for more details.
+"""
+
+from __future__ import print_function
+from __future__ import unicode_literals
+from __future__ import division
+
+import sys
+import time
+import multiprocessing
+import signal
+from functools import wraps
+
+############################################################
+# Timeout
+############################################################
+
+# http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
+# Used work of Stephen "Zero" Chappell <Noctis.Skytower@gmail.com>
+# in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
+from format_convert.judge_platform import get_platform
+
+
class TimeoutError(AssertionError):

    """Thrown when a timeout occurs in the `timeout` context manager."""

    # NOTE: deliberately shadows the builtin TimeoutError so existing
    # callers of this module keep working; derives from AssertionError.
    def __init__(self, value="Timed Out"):
        self.value = value

    def __str__(self):
        return "{!r}".format(self.value)
+
+
+def _raise_exception(exception, exception_message):
+    """ This function checks if a exception message is given.
+
+    If there is no exception message, the default behaviour is maintained.
+    If there is an exception message, the message is passed to the exception with the 'value' keyword.
+    """
+    if exception_message is None:
+        raise exception()
+    else:
+        raise exception(exception_message)
+
+
def timeout(seconds=None, use_signals=True, timeout_exception=TimeoutError, exception_message=None):
    """Add a timeout parameter to a function and return it.

    :param seconds: optional time limit in seconds or fractions of a second. If None is passed,
        no timeout is applied unless the caller supplies a per-call ``timeout=`` keyword.
        This adds some flexibility to the usage: you can disable timing out depending on the settings.
    :type seconds: float
    :param use_signals: flag indicating whether signals should be used for timing function out or the multiprocessing
        When using multiprocessing, timeout granularity is limited to 10ths of a second.
    :type use_signals: bool

    :raises: TimeoutError if time limit is reached

    It is illegal to pass anything other than a function as the first
    parameter. The function is wrapped and returned to the caller.
    """
    def decorate(function):
        # SIGALRM does not exist on Windows: return the function undecorated.
        if get_platform() == "Windows":
            @wraps(function)
            def new_function(*args, **kwargs):
                return function(*args, **kwargs)
            return new_function

        else:
            if use_signals:
                def handler(signum, frame):
                    _raise_exception(timeout_exception, exception_message)

                @wraps(function)
                def new_function(*args, **kwargs):
                    # Per-call override: f(..., timeout=3) beats the decorator default.
                    new_seconds = kwargs.pop('timeout', seconds)
                    if not new_seconds:
                        return function(*args, **kwargs)
                    # BUGFIX: previously the itimer was armed *before* an early
                    # return taken when the decorator-level `seconds` was falsy,
                    # which skipped the finally-cleanup and left a live SIGALRM
                    # behind after a successful call.
                    old = signal.signal(signal.SIGALRM, handler)
                    signal.setitimer(signal.ITIMER_REAL, new_seconds)
                    try:
                        return function(*args, **kwargs)
                    finally:
                        signal.setitimer(signal.ITIMER_REAL, 0)
                        signal.signal(signal.SIGALRM, old)
                return new_function
            else:
                @wraps(function)
                def new_function(*args, **kwargs):
                    # Multiprocessing fallback: run the target in a child process.
                    timeout_wrapper = _Timeout(function, timeout_exception, exception_message, seconds)
                    return timeout_wrapper(*args, **kwargs)
                return new_function

    return decorate
+
+
# Decorator wrapped as a class so it can be pickled.
class TimeoutClass:
    """Picklable wrapper binding a function, a limit and an exception type."""

    def __init__(self, func, seconds, timeout_exception):
        self.func = func
        self.seconds = seconds
        self.timeout_exception = timeout_exception

    def run(self, *args, **kwargs):
        wrapper = _Timeout(self.func, self.timeout_exception, None, self.seconds)
        return wrapper(*args, **kwargs)
+
+
+def _target(queue, function, *args, **kwargs):
+    """Run a function with arguments and return output via a queue.
+
+    This is a helper function for the Process created in _Timeout. It runs
+    the function with positional arguments and keyword arguments and then
+    returns the function's output by way of a queue. If an exception gets
+    raised, it is returned to _Timeout to be raised by the value property.
+    """
+    try:
+        queue.put((True, function(*args, **kwargs)))
+    except:
+        queue.put((False, sys.exc_info()[1]))
+
+
class _Timeout(object):

    """Wrap a function and add a timeout (limit) attribute to it.

    Instances of this class are automatically generated by the add_timeout
    function defined above. Wrapping a function allows asynchronous calls
    to be made and termination of execution after a timeout has passed.
    """

    def __init__(self, function, timeout_exception, exception_message, limit):
        """Initialize instance in preparation for being called."""
        self.__limit = limit                        # seconds, or None for no limit
        self.__function = function
        self.__timeout_exception = timeout_exception
        self.__exception_message = exception_message
        self.__name__ = function.__name__           # mimic functools.wraps
        self.__doc__ = function.__doc__
        self.__timeout = time.time()                # absolute deadline, set in __call__
        self.__process = multiprocessing.Process()  # placeholders, replaced per call
        self.__queue = multiprocessing.Queue()

    def __call__(self, *args, **kwargs):
        """Execute the embedded function object asynchronously.

        The function given to the constructor is transparently called and
        requires that "ready" be intermittently polled. If and when it is
        True, the "value" property may then be checked for returned data.
        """
        # Per-call `timeout=` keyword overrides the stored limit.
        self.__limit = kwargs.pop('timeout', self.__limit)
        self.__queue = multiprocessing.Queue(1)
        args = (self.__queue, self.__function) + args

        # NOTE(review): mutates the *global* multiprocessing start method on
        # every call — affects all other Process users in this interpreter.
        multiprocessing.set_start_method("spawn", force=True)
        self.__process = multiprocessing.Process(target=_target,
                                                 args=args,
                                                 kwargs=kwargs)
        self.__process.daemon = True
        self.__process.start()
        if self.__limit is not None:
            self.__timeout = self.__limit + time.time()
        # Busy-poll every 10 ms; `ready` itself raises (via cancel) on deadline.
        while not self.ready:
            time.sleep(0.01)
        return self.value

    def cancel(self):
        """Terminate any possible execution of the embedded function."""
        if self.__process.is_alive():
            print("terminate process", self.__process.pid)
            # self.__process.terminate()
            self.__process.kill()

        # Always re-raise the configured timeout exception to the caller.
        _raise_exception(self.__timeout_exception, self.__exception_message)

    @property
    def ready(self):
        """Read-only property indicating status of "value" property."""
        # Deadline passed -> cancel() kills the child and raises.
        if self.__limit and self.__timeout < time.time():
            self.cancel()
        # Queue has maxsize 1, so full() means the single result has arrived.
        return self.__queue.full() and not self.__queue.empty()

    @property
    def value(self):
        """Read-only property containing data returned from function."""
        # NOTE(review): implicitly returns None when not ready — callers
        # are expected to poll `ready` first (as __call__ does).
        if self.ready is True:
            flag, load = self.__queue.get()
            if flag:
                return load
            raise load

+ 0 - 0
ocr/model/2.0/cls/inference.pdiparams


+ 0 - 0
ocr/model/2.0/cls/inference.pdiparams.info


+ 0 - 0
ocr/model/2.0/cls/inference.pdmodel


+ 0 - 0
ocr/model/2.0/det/inference.pdiparams


+ 0 - 0
ocr/model/2.0/det/inference.pdiparams.info


+ 0 - 0
ocr/model/2.0/det/inference.pdmodel


+ 0 - 0
ocr/model/2.0/rec/ch/inference.pdiparams


+ 0 - 0
ocr/model/2.0/rec/ch/inference.pdiparams.info


+ 0 - 0
ocr/model/2.0/rec/ch/inference.pdmodel


+ 0 - 0
ocr/model/2.0/rec/ch/origin_model/mobile/inference.pdiparams


+ 0 - 0
ocr/model/2.0/rec/ch/origin_model/mobile/inference.pdiparams.info


+ 0 - 0
ocr/model/2.0/rec/ch/origin_model/mobile/inference.pdmodel


+ 0 - 0
ocr/model/2.0/rec/ch/origin_model/server/inference.pdiparams


+ 0 - 0
ocr/model/2.0/rec/ch/origin_model/server/inference.pdiparams.info


+ 0 - 0
ocr/model/2.0/rec/ch/origin_model/server/inference.pdmodel


+ 0 - 0
ocr/model/2.0/rec/ch/production_model/mobile/inference.pdiparams


+ 0 - 0
ocr/model/2.0/rec/ch/production_model/mobile/inference.pdiparams.info


+ 0 - 0
ocr/model/2.0/rec/ch/production_model/mobile/inference.pdmodel


+ 54 - 0
ocr/my_infer.py

@@ -0,0 +1,54 @@
# Manual PaddleOCR experiment: run full OCR on one image and visualise boxes.
import cv2
from PIL import Image
from paddleocr import PaddleOCR
from tools.infer.utility import draw_ocr
import numpy as np
from format_convert.convert import remove_red_seal, remove_underline

# path = "../temp/complex/710.png"
# path = "../test_files/开标记录表3_page_0.png"
# path = "D:\\Project\\format_conversion\\appendix_test\\temp\\00e959a0bc9011ebaf5a00163e0ae709" + \
#         "\\00e95f7cbc9011ebaf5a00163e0ae709_pdf_page0.png"
# path = "../去章文字.jpg"
# path = "../1.jpg"
# path = "../real1.png"
# NOTE: the second assignment below intentionally overrides the first path.
path = "../temp/f1fe9c4ac8e511eb81d700163e0857b6/f1fea1e0c8e511eb81d700163e0857b6.png"
path = "../翻转1.jpg"
# remove the official red seal
# image_np = cv2.imread(path)
# cv2.imshow("origin image", image_np)
# cv2.waitKey(0)
# image_np = remove_red_seal(image_np)
# cv2.imwrite("../去章文字.jpg", image_np)

# remove underlines
# image_np = cv2.imread(path)
# remove_underline(image_np)

with open(path, "rb") as f:
    image = f.read()
ocr_model = PaddleOCR(use_angle_cls=True, lang="ch")
image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
# # convert BGR to RGB
np_images = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
# np_images = [cv2.imread(img_data)]
results = ocr_model.ocr(np_images, det=True, rec=True, cls=True)

# Each result line is (bbox, (text, score)).
bbox_list = []
text_list = []
score_list = []
for line in results:
    text_list.append(line[-1][0])
    bbox_list.append(line[0])
    score_list.append(line[-1][1])
    # print("len(text_list)", len(text_list))
    # print("len(bbox_list)", len(bbox_list))
    # print("score_list", score_list)

image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
boxes = bbox_list

image = draw_ocr(image, boxes, text_list, score_list, drop_score=0.2)
print(type(image))
image = Image.fromarray(image)
image.show("image")

+ 75 - 0
ocr/my_infer_hub.py

@@ -0,0 +1,75 @@
+import base64
+import json
+import re
+import sys
+import time
+import paddlehub as hub
+import cv2
+from PIL import Image
+import logging
+import numpy as np
+
# Module config: model flavour ("mobile"/"server"), CPU-only inference, and
# a sample image loaded at import time for the demo functions below.
use_model = "mobile"
use_gpu = False

# img_data = "./1.jpg"
img_data = "D:/Project/PaddleOCR-release-2.0/train_data/bidi_data/orgs_data/train/text_304.jpg"
np_images = [cv2.imread(img_data)]
+
+
def only_recognize():
    # NOTE(review): unfinished stub — loads the *detection* module and only
    # assigns it to `results` without calling any recognition API.
    ocr = hub.Module(name="chinese_text_detection_db_"+use_model)
    results = ocr
+
+
def only_detect():
    """Run text-box detection only (PaddleHub DB detector) on the sample image."""
    ocr = hub.Module(name="chinese_text_detection_db_"+use_model)
    results = ocr.detect_text(
        images=np_images,           # image data, ndarray.shape [H, W, C], BGR format
        use_gpu=use_gpu,            # use GPU; if so, set CUDA_VISIBLE_DEVICES first
        output_dir='../ocr_result',  # directory for saved images, default ocr_result
        visualization=True,         # whether to save detection results as images
        box_thresh=0.5              # confidence threshold for detected text boxes
        )                           # confidence threshold for recognised Chinese text
    # NOTE(review): prints the whole `results` list once per element —
    # probably meant `print(result)`.
    for result in results:
        print(results)
+
+
def detect_and_recognize():
    """Run full detection + recognition (CRNN) and print each text hit."""
    ocr = hub.Module(name="chinese_ocr_db_crnn_"+use_model)
    results = ocr.recognize_text(
        images=np_images,         # image data, ndarray.shape [H, W, C], BGR format
        use_gpu=use_gpu,            # use GPU; if so, set CUDA_VISIBLE_DEVICES first
        output_dir='../ocr_result',  # directory for saved images, default ocr_result
        visualization=True,       # whether to save recognition results as images
        box_thresh=0.5,           # confidence threshold for detected text boxes
        text_thresh=0.0)          # confidence threshold for recognised Chinese text

    for result in results:
        data = result['data']
        save_path = result['save_path']
        for infomation in data:
            if infomation['text'] == "":
                print("no text")
                continue
            print('text: ', infomation['text'], '\nconfidence: ',
                  infomation['confidence'], '\ntext_box_position: ',
                  infomation['text_box_position'])
+
+
def image_bigger():
    """Rescale ./1.jpg (x0.7 horizontally, x1.2 vertically, cubic) to ./1_1.jpg."""
    img = cv2.imread("./1.jpg", -1)
    h, w = img.shape[:2]
    print(h, w)
    scaled = cv2.resize(img, (0, 0), fx=0.7, fy=1.2, interpolation=cv2.INTER_CUBIC)
    print(scaled.shape[:2])
    cv2.imwrite("./1_1.jpg", scaled)
    # img.save("./1_1.jpg", "jpeg")
+
+
if __name__ == '__main__':
    # Manual entry point: detection-only demo by default.
    only_detect()
    # image_bigger()
    # detect_and_recognize()

+ 155 - 0
ocr/ocr_interface.py

@@ -0,0 +1,155 @@
+import base64
+import json
+import multiprocessing as mp
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+import time
+import traceback
+from multiprocessing.context import Process
+import cv2
+import requests
+import logging
+import numpy as np
+os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
+from ocr.paddleocr import PaddleOCR
+
+
logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def log(msg):
    '''
    @summary: log an info-level message
    '''
    logger.info(msg)
+
+
def ocr(data, ocr_model):
    """Decode base64 *data* and OCR it with *ocr_model*.

    :param data: base64-encoded image bytes.
    :param ocr_model: a loaded PaddleOCR instance.
    :return: the dict produced by picture2text ({"text": ..., "bbox": ...}).
    :raises TimeoutError: propagated from picture2text.
    """
    try:
        img_data = base64.b64decode(data)
        text = picture2text(img_data, ocr_model)
        return text
    except TimeoutError:
        # BUGFIX: bare `raise` keeps the original exception instance and
        # traceback (previously `raise TimeoutError` replaced it with a
        # fresh, message-less one).
        raise
+
+
flag = 0
def picture2text(img_data, ocr_model):
    """Run OCR on raw image bytes; return {"text": str(list), "bbox": str(list)}.

    Both values are *repr strings* of Python lists (callers eval() them);
    empty lists are returned on any failure.
    """
    logging.info("into ocr_interface picture2text")
    try:
        start_time = time.time()
        # decode the binary stream into an np.ndarray [np.uint8 pixels]
        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
        # convert BGR to RGB
        try:
            np_images = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        except cv2.error as e:
            # NOTE(review): any cv2.error other than src.empty() is swallowed
            # here, leaving np_images unbound; the resulting NameError is then
            # caught by the generic except below — confirm intended.
            if "src.empty()" in str(e):
                logging.info("ocr_interface picture2text image is empty!")
                return {"text": str([]), "bbox": str([])}
        # resize
        # cv2.imshow("before resize", np_images)
        # print("np_images.shape", np_images.shape)

        # best_h, best_w = get_best_predict_size(np_images)
        # np_images = cv2.resize(np_images, (best_w, best_h), interpolation=cv2.INTER_AREA)

        # cv2.imshow("after resize", np_images)
        # print("np_images.shape", np_images.shape)
        # cv2.waitKey(0)

        # predict
        results = ocr_model.ocr(np_images, det=True, rec=True, cls=True)

        # collect per-line recognition results: line is (bbox, (text, score))
        text_list = []
        bbox_list = []
        for line in results:
            # print("ocr_interface line", line)
            text_list.append(line[-1][0])
            bbox_list.append(line[0])

        # debug: visualise bboxes
        # img = np.zeros((np_images.shape[1], np_images.shape[0]), np.uint8)
        # img.fill(255)
        # for box in bbox_list:
        #     print(box)
        #     cv2.rectangle(img, (int(box[0][0]), int(box[0][1])),
        #                   (int(box[2][0]), int(box[2][1])), (0, 0, 255), 1)
        # cv2.imshow("bbox", img)
        # cv2.waitKey(0)

        logging.info("ocr model use time: " + str(time.time()-start_time))
        return {"text": str(text_list), "bbox": str(bbox_list)}

    except TimeoutError:
        raise TimeoutError
    except Exception as e:
        logging.info("picture2text error!")
        print("picture2text", traceback.print_exc())
        return {"text": str([]), "bbox": str([])}
+
+
def get_best_predict_size(image_np):
    """Snap an image's (height, width) to the nearest supported predict size.

    Ties between two equally-close sizes resolve to the larger (earlier) one.
    """
    sizes = [1280, 1152, 1024, 896, 768, 640, 512, 384, 256, 128]
    h, w = image_np.shape[0], image_np.shape[1]
    # min() returns the first minimal element, preserving the original
    # first-match tie-breaking over the descending size list.
    best_height = min(sizes, key=lambda s: abs(h - s))
    best_width = min(sizes, key=lambda s: abs(w - s))
    return best_height, best_width
+
+
class OcrModels:
    """Loads a PaddleOCR instance once and hands it out via get_model()."""
    def __init__(self):
        try:
            # Chinese model with angle classification enabled.
            self.ocr_model = PaddleOCR(use_angle_cls=True, lang="ch")
        except:
            # NOTE(review): print(traceback.print_exc()) prints the traceback
            # and then prints None (print_exc returns None) — harmless quirk.
            print(traceback.print_exc())
            raise RuntimeError

    def get_model(self):
        return self.ocr_model
+
+
if __name__ == '__main__':
    # Manual smoke test: OCR a local image twice, then print texts sorted
    # by their top-left box coordinates.
    # if len(sys.argv) == 2:
    #     port = int(sys.argv[1])
    # else:
    #     port = 15011
    #
    # app.run(host='0.0.0.0', port=port, threaded=False, debug=False)
    # log("OCR running")
    # file_path = "C:/Users/Administrator/Desktop/error1.png"
    file_path = "1.png"

    with open(file_path, "rb") as f:
        file_bytes = f.read()
    file_base64 = base64.b64encode(file_bytes)

    ocr_model = OcrModels().get_model()
    result = ocr(file_base64, ocr_model)
    result = ocr(file_base64, ocr_model)

    # NOTE: eval() on our own repr-string output — trusted input only.
    text_list = eval(result.get("text"))
    box_list = eval(result.get("bbox"))

    new_list = []
    for i in range(len(text_list)):
        new_list.append([text_list[i], box_list[i]])

    # print(new_list[0][1])
    new_list.sort(key=lambda x: (x[1][1][0], x[1][0][0]))

    for t in new_list:
        print(t[0])

+ 362 - 0
ocr/paddleocr.py

@@ -0,0 +1,362 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+__dir__ = os.path.dirname(__file__)
+sys.path.append(os.path.join(__dir__, ''))
+project_path = os.path.abspath(__dir__)
+# project_path = ""
+
+import cv2
+import numpy as np
+from pathlib import Path
+import tarfile
+import requests
+from tqdm import tqdm
+
+os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
+from ocr.tools.infer import predict_system
+from ocr.ppocr.utils.logging import get_logger
+
+logger = get_logger()
+from ocr.ppocr.utils.utility import check_and_read_gif, get_image_file_list
+
+__all__ = ['PaddleOCR']
+
# All v2.0 mobile inference models live under the same bcebos prefix.
_BCEBOS = 'https://paddleocr.bj.bcebos.com/dygraph_v2.0'

# Download endpoints for detection / recognition / classification models;
# each recognition language bundles the path of its character dictionary.
model_urls = {
    'det': _BCEBOS + '/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
    'rec': {
        'ch': {'url': _BCEBOS + '/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
               'dict_path': './ppocr/utils/ppocr_keys_v1.txt'},
        'en': {'url': _BCEBOS + '/multilingual/en_number_mobile_v2.0_rec_infer.tar',
               'dict_path': './ppocr/utils/dict/en_dict.txt'},
        'french': {'url': _BCEBOS + '/multilingual/french_mobile_v2.0_rec_infer.tar',
                   'dict_path': './ppocr/utils/dict/french_dict.txt'},
        'german': {'url': _BCEBOS + '/multilingual/german_mobile_v2.0_rec_infer.tar',
                   'dict_path': './ppocr/utils/dict/german_dict.txt'},
        'korean': {'url': _BCEBOS + '/multilingual/korean_mobile_v2.0_rec_infer.tar',
                   'dict_path': './ppocr/utils/dict/korean_dict.txt'},
        'japan': {'url': _BCEBOS + '/multilingual/japan_mobile_v2.0_rec_infer.tar',
                  'dict_path': './ppocr/utils/dict/japan_dict.txt'},
    },
    'cls': _BCEBOS + '/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
}

SUPPORT_DET_MODEL = ['DB']
VERSION = 2.0
SUPPORT_REC_MODEL = ['CRNN']
# Models are stored inside the project tree (next to this file) instead of
# the per-user cache that upstream PaddleOCR uses:
# BASE_DIR = os.path.expanduser("~/.paddleocr/")
BASE_DIR = project_path + "/model/"
+
+
def download_with_progressbar(url, save_path):
    """Stream *url* to *save_path*, showing a tqdm progress bar.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: destination file path (overwritten).

    Exits the interpreter with a non-zero status when the download is
    missing a Content-Length or is incomplete. (BUG FIX: the previous
    version called sys.exit(0), which signals *success* to the caller's
    shell even though the download failed.)
    """
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
    with open(save_path, 'wb') as out_file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            out_file.write(data)
    progress_bar.close()
    if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
        logger.error("Something went wrong while downloading models")
        sys.exit(1)  # non-zero exit: download failed
+
+
def maybe_download(model_storage_directory, url):
    """Download and unpack the inference model tarball into
    *model_storage_directory* unless the extracted files already exist.

    Only the three inference artifacts are extracted (flattened into the
    target directory); other tar members are ignored. The temporary
    tarball is removed afterwards.
    """
    tar_file_name_list = [
        'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
    ]
    if not os.path.exists(
            os.path.join(model_storage_directory, 'inference.pdiparams')
    ) or not os.path.exists(
            os.path.join(model_storage_directory, 'inference.pdmodel')):
        tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
        print('download {} to {}'.format(url, tmp_path))
        os.makedirs(model_storage_directory, exist_ok=True)
        download_with_progressbar(url, tmp_path)
        with tarfile.open(tmp_path, 'r') as tar_obj:
            for member in tar_obj.getmembers():
                filename = None
                for tar_file_name in tar_file_name_list:
                    if tar_file_name in member.name:
                        filename = tar_file_name
                if filename is None:
                    continue
                extracted = tar_obj.extractfile(member)
                # extractfile() returns None for links/directories; skip
                # those instead of crashing on .read(). Also avoid
                # shadowing the builtin name `file`.
                if extracted is None:
                    continue
                with open(
                        os.path.join(model_storage_directory, filename),
                        'wb') as out_file:
                    out_file.write(extracted.read())
        os.remove(tmp_path)
+
+
def parse_args(mMain=True, add_help=True):
    """Build the OCR engine configuration.

    Args:
        mMain: True -> parse real command-line flags; False -> return the
            hard-coded defaults used when PaddleOCR is driven from Python.
        add_help: forwarded to the ArgumentParser (CLI mode only).

    Returns:
        argparse.Namespace with the full engine configuration.
    """
    import argparse

    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    if mMain:
        parser = argparse.ArgumentParser(add_help=add_help)
        # params for prediction engine
        parser.add_argument("--use_gpu", type=str2bool, default=True)
        parser.add_argument("--ir_optim", type=str2bool, default=True)
        parser.add_argument("--use_tensorrt", type=str2bool, default=False)
        parser.add_argument("--gpu_mem", type=int, default=8000)

        # params for text detector
        parser.add_argument("--image_dir", type=str)
        parser.add_argument("--det_algorithm", type=str, default='DB')
        parser.add_argument("--det_model_dir", type=str, default=None)
        parser.add_argument("--det_limit_side_len", type=float, default=960)
        parser.add_argument("--det_limit_type", type=str, default='max')

        # DB params
        parser.add_argument("--det_db_thresh", type=float, default=0.3)
        parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
        parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
        # BUG FIX: argparse's type=bool turns ANY non-empty string (even
        # "False") into True; str2bool parses the flag value correctly.
        parser.add_argument("--use_dilation", type=str2bool, default=False)

        # EAST params
        parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
        parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
        parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

        # params for text recognizer
        parser.add_argument("--rec_algorithm", type=str, default='CRNN')
        parser.add_argument("--rec_model_dir", type=str, default=None)
        parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
        parser.add_argument("--rec_char_type", type=str, default='ch')
        parser.add_argument("--rec_batch_num", type=int, default=30)
        parser.add_argument("--max_text_length", type=int, default=25)
        parser.add_argument("--rec_char_dict_path", type=str, default=None)
        # BUG FIX: was type=bool (see note above).
        parser.add_argument("--use_space_char", type=str2bool, default=True)
        parser.add_argument("--drop_score", type=float, default=0.5)

        # params for text classifier
        parser.add_argument("--cls_model_dir", type=str, default=None)
        parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
        # NOTE(review): type=list splits a user-supplied value into single
        # characters; only the default is usable as-is.
        parser.add_argument("--label_list", type=list, default=['0', '180'])
        parser.add_argument("--cls_batch_num", type=int, default=30)
        parser.add_argument("--cls_thresh", type=float, default=0.9)

        # BUG FIX: the next two were type=bool (see note above).
        parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
        parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
        parser.add_argument("--use_pdserving", type=str2bool, default=False)

        parser.add_argument("--lang", type=str, default='ch')
        parser.add_argument("--det", type=str2bool, default=True)
        parser.add_argument("--rec", type=str2bool, default=True)
        parser.add_argument("--use_angle_cls", type=str2bool, default=False)
        return parser.parse_args()
    else:
        return argparse.Namespace(
            use_gpu=False,
            ir_optim=True,
            use_tensorrt=False,
            gpu_mem=8000,
            image_dir='',
            det_algorithm='DB',
            det_model_dir=None,
            det_limit_side_len=1280,
            det_limit_type='max',
            det_db_thresh=0.5,
            # lower det_db_box_thresh if whole lines are being missed
            det_db_box_thresh=0.5,
            # det_db_unclip_ratio: how tightly detection boxes hug the text
            det_db_unclip_ratio=2,
            # dilate the text regions before box extraction
            use_dilation=False,
            det_east_score_thresh=0.8,
            det_east_cover_thresh=0.1,
            det_east_nms_thresh=0.2,
            rec_algorithm='CRNN',
            rec_model_dir=None,
            rec_image_shape="3, 32, 1000",
            rec_char_type='ch',
            rec_batch_num=30,
            max_text_length=128,
            rec_char_dict_path='ocr/ppocr/utils/ppocr_keys_v1.txt',
            use_space_char=True,
            drop_score=0.5,
            cls_model_dir=None,
            cls_image_shape="3, 32, 1000",
            label_list=['0', '180'],
            cls_batch_num=30,
            cls_thresh=0.9,
            enable_mkldnn=False,
            use_zero_copy_run=True,
            use_pdserving=False,
            lang='ch',
            det=True,
            rec=True,
            use_angle_cls=False)
+
+
class PaddleOCR(predict_system.TextSystem):
    """End-to-end OCR pipeline (detection + angle classification +
    recognition) built on the project's TextSystem, downloading any
    missing inference models on first use."""

    def __init__(self, **kwargs):
        """
        paddleocr package

        args:
            **kwargs: any option shown by ``paddleocr --help``; values
                override the Python-API defaults from
                parse_args(mMain=False).
        """
        postprocess_params = parse_args(mMain=False, add_help=False)
        postprocess_params.__dict__.update(**kwargs)
        self.use_angle_cls = postprocess_params.use_angle_cls
        lang = postprocess_params.lang
        assert lang in model_urls[
            'rec'], 'param lang must in {}, but got {}'.format(
                model_urls['rec'].keys(), lang)
        # Fall back to the character dictionary bundled with the language.
        if postprocess_params.rec_char_dict_path is None:
            postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
                'dict_path']

        # init model dir: det / rec / cls live under BASE_DIR/<VERSION>/
        if postprocess_params.det_model_dir is None:
            postprocess_params.det_model_dir = os.path.join(
                BASE_DIR, '{}/det'.format(VERSION))
        if postprocess_params.rec_model_dir is None:
            postprocess_params.rec_model_dir = os.path.join(
                BASE_DIR, '{}/rec/{}'.format(VERSION, lang))
        if postprocess_params.cls_model_dir is None:
            postprocess_params.cls_model_dir = os.path.join(
                BASE_DIR, '{}/cls'.format(VERSION))
        # NOTE(review): debug print of the full config — consider
        # logger.debug instead.
        print(postprocess_params)

        # download model if the inference files are not present yet
        maybe_download(postprocess_params.det_model_dir, model_urls['det'])
        maybe_download(postprocess_params.rec_model_dir,
                       model_urls['rec'][lang]['url'])
        maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])

        if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
            logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
            # NOTE(review): exits with status 0 on a configuration error;
            # a non-zero code would be more conventional.
            sys.exit(0)
        if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
            logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
            sys.exit(0)

        # Resolve the char dict path relative to the repository root.
        postprocess_params.rec_char_dict_path = str(
            Path(__file__).parent.parent / postprocess_params.rec_char_dict_path)

        # init det_model and rec_model
        super().__init__(postprocess_params)

    def ocr(self, img, det=True, rec=True, cls=False):
        """
        ocr with paddleocr

        args:
            img: image to OCR; supports ndarray, an image path/URL string,
                or a list of ndarrays (a list is only valid with det=False)
            det: use text detection or not; if False, only rec will run
            rec: use text recognition or not; if False, only det will run
            cls: run the 180-degree angle classifier before recognition
        """
        # NOTE(review): debug print of the call flags.
        print(det, rec, cls)
        assert isinstance(img, (np.ndarray, list, str))
        if isinstance(img, list) and det == True:
            logger.error('When input a list of images, det must be false')
            exit(0)

        self.use_angle_cls = cls
        if isinstance(img, str):
            # download net image
            if img.startswith('http'):
                download_with_progressbar(img, 'tmp.jpg')
                img = 'tmp.jpg'
            image_file = img
            img, flag = check_and_read_gif(image_file)
            if not flag:
                # Decoding via a byte buffer handles non-ASCII paths that
                # cv2.imread cannot open directly.
                with open(image_file, 'rb') as f:
                    np_arr = np.frombuffer(f.read(), dtype=np.uint8)
                    img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
            if img is None:
                logger.error("error in loading image:{}".format(image_file))
                return None
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            # Grayscale input: expand to 3 channels for the models.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if det and rec:
            # Full pipeline: boxes plus recognized text per box.
            dt_boxes, rec_res = self.__call__(img)
            print("paddleocr.py dt_boxes", len(dt_boxes))
            print("paddleocr.py rec_res", len(rec_res))
            return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
        elif det and not rec:
            # Detection only: return the box coordinates.
            dt_boxes, elapse = self.text_detector(img)
            if dt_boxes is None:
                return None
            return [box.tolist() for box in dt_boxes]
        else:
            # Recognition (and optionally classification) only.
            if not isinstance(img, list):
                img = [img]
            if self.use_angle_cls:
                img, cls_res, elapse = self.text_classifier(img)
                if not rec:
                    return cls_res
            rec_res, elapse = self.text_recognizer(img)
            return rec_res
+
+
def main(mMain=True):
    """CLI entry point: run OCR over the image(s) in ``--image_dir``.

    Args:
        mMain: forwarded to parse_args; False uses the Python-API
            defaults instead of parsing sys.argv.
    """
    # for cmd
    args = parse_args(mMain)

    # image_dir may be a URL or a local path. BUG FIX: with mMain=True and
    # no --image_dir flag, args.image_dir is None and None.startswith()
    # raised AttributeError; guard before dispatching on the prefix.
    image_dir = args.image_dir
    if image_dir and image_dir.startswith('http'):
        download_with_progressbar(image_dir, 'tmp.jpg')
        image_file_list = ['tmp.jpg']
    else:
        image_file_list = get_image_file_list(args.image_dir)
    if len(image_file_list) == 0:
        logger.error('no images find in {}'.format(args.image_dir))
        return

    ocr_engine = PaddleOCR(**(args.__dict__))
    for img_path in image_file_list:
        logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
        result = ocr_engine.ocr(img_path,
                                det=args.det,
                                rec=args.rec,
                                cls=args.use_angle_cls)
        if result is not None:
            for line in result:
                logger.info(line)
+
+
# Run the CLI with the Python-API defaults (mMain=False skips argv parsing).
if __name__ == '__main__':
    main(False)

+ 13 - 0
ocr/ppocr/__init__.py

@@ -0,0 +1,13 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+ 110 - 0
ocr/ppocr/data/__init__.py

@@ -0,0 +1,110 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+import sys
+import numpy as np
+import paddle
+import signal
+import random
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import copy
+from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
+# from paddle.fluid.io import DataLoader
+# from paddle.fluid.dataloader import Dataset, BatchSampler, DistributedBatchSampler
+import paddle.distributed as dist
+
+from ocr.ppocr.data.imaug import transform, create_operators
+from ocr.ppocr.data.simple_dataset import SimpleDataSet
+from ocr.ppocr.data.lmdb_dataset import LMDBDataSet
+
+__all__ = ['build_dataloader', 'transform', 'create_operators']
+
+
def term_mp(sig_num, frame):
    """Signal handler: SIGKILL the whole process group.

    Ensures DataLoader worker processes die together with the main
    process (the handler does not return — the group includes self).
    """
    pid = os.getpid()
    pgid = os.getpgid(pid)
    print("main proc {} exit, kill process group {}".format(pid, pgid))
    os.killpg(pgid, signal.SIGKILL)
+
+
+signal.signal(signal.SIGINT, term_mp)
+signal.signal(signal.SIGTERM, term_mp)
+
+
def build_dataloader(config, mode, device, logger, seed=None):
    """Create a paddle DataLoader for the given mode.

    Args:
        config: full config dict; ``config[mode]`` must contain a
            ``dataset`` section (with a ``name``) and a ``loader`` section.
        mode: one of 'Train', 'Eval', 'Test'.
        device: paddle places passed through to the DataLoader.
        logger: logger handed to the dataset constructor.
        seed: optional seed forwarded to the dataset.

    Returns:
        paddle.io.DataLoader over the configured dataset.
    """
    config = copy.deepcopy(config)

    # Validate the requested mode and dataset class.
    assert mode in ['Train', 'Eval', 'Test'
                    ], "Mode should be Train, Eval or Test."
    module_name = config[mode]['dataset']['name']

    # Resolve the dataset class from an explicit whitelist instead of
    # eval() on a config-supplied string (same classes, no code eval).
    support_dict = {'SimpleDataSet': SimpleDataSet, 'LMDBDataSet': LMDBDataSet}
    assert module_name in support_dict, Exception(
        'DataSet only support {}'.format(list(support_dict)))
    dataset = support_dict[module_name](config, mode, logger, seed)

    # Read the loader parameters.
    loader_config = config[mode]['loader']
    batch_size = loader_config['batch_size_per_card']
    drop_last = loader_config['drop_last']
    shuffle = loader_config['shuffle']
    num_workers = loader_config['num_workers']
    use_shared_memory = loader_config.get('use_shared_memory', True)

    # Train mode shards batches across cards; other modes run single-card.
    if mode == "Train":
        # Distribute data to multiple cards
        batch_sampler = DistributedBatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last)
    else:
        # Distribute data to single card
        batch_sampler = BatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last)

    # Build the reader object from the settings above.
    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=use_shared_memory)
    return data_loader

+ 62 - 0
ocr/ppocr/data/imaug/__init__.py

@@ -0,0 +1,62 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from .iaa_augment import IaaAugment
+from .make_border_map import MakeBorderMap
+from .make_shrink_map import MakeShrinkMap
+from .random_crop_data import EastRandomCropData, PSERandomCrop
+
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
+from .randaugment import RandAugment
+from .operators import *
+from .label_ops import *
+
+from .east_process import *
+from .sast_process import *
+
+
def transform(data, ops=None):
    """Apply each op in *ops* to *data* in order.

    Returns None as soon as any op returns None (i.e. the sample is
    dropped); with no ops, *data* passes through unchanged.
    """
    for op in (ops or []):
        data = op(data)
        if data is None:
            return None
    return data
+
+
def create_operators(op_param_list, global_config=None):
    """Instantiate augmentation operators from a config list.

    Args:
        op_param_list: list of single-key dicts, ``{OpName: kwargs}``.
        global_config: optional kwargs merged into every operator's
            params (note: merged into the config dict in place).

    Returns:
        list of constructed operator objects, in input order.
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for op_conf in op_param_list:
        assert isinstance(op_conf, dict) and len(op_conf) == 1, "yaml format error"
        op_name = list(op_conf)[0]
        if op_conf[op_name] is None:
            param = {}
        else:
            param = op_conf[op_name]
        if global_config is not None:
            param.update(global_config)
        # eval() resolves the operator class by name from this module's
        # globals (populated by the star imports above); the config is
        # trusted project input, not user data.
        ops.append(eval(op_name)(**param))
    return ops

+ 439 - 0
ocr/ppocr/data/imaug/east_process.py

@@ -0,0 +1,439 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import math
+import cv2
+import numpy as np
+import json
+import sys
+import os
+
+__all__ = ['EASTProcessTrain']
+
+
+class EASTProcessTrain(object):
+    def __init__(self,
+                 image_shape = [512, 512],
+                 background_ratio = 0.125,
+                 min_crop_side_ratio = 0.1,
+                 min_text_size = 10,
+                 **kwargs):
+        self.input_size = image_shape[1]
+        self.random_scale = np.array([0.5, 1, 2.0, 3.0])
+        self.background_ratio = background_ratio
+        self.min_crop_side_ratio = min_crop_side_ratio
+        self.min_text_size = min_text_size
+
+    def preprocess(self, im):
+        input_size = self.input_size
+        im_shape = im.shape
+        im_size_min = np.min(im_shape[0:2])
+        im_size_max = np.max(im_shape[0:2])
+        im_scale = float(input_size) / float(im_size_max)
+        im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
+        img_mean = [0.485, 0.456, 0.406]
+        img_std = [0.229, 0.224, 0.225]
+        # im = im[:, :, ::-1].astype(np.float32)
+        im = im / 255
+        im -= img_mean
+        im /= img_std
+        new_h, new_w, _ = im.shape
+        im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
+        im_padded[:new_h, :new_w, :] = im
+        im_padded = im_padded.transpose((2, 0, 1))
+        im_padded = im_padded[np.newaxis, :]
+        return im_padded, im_scale
+
    def rotate_im_poly(self, im, text_polys):
        """
        Randomly rotate the image by 90 / 180 / 270 degrees and rotate the
        text polygons to match.

        :param im: HWC image array
        :param text_polys: (N, 4, 2) array of text quadrilaterals
        :return: (rotated_image, rotated_polys) with polys as float32
        """
        im_w, im_h = im.shape[1], im.shape[0]
        dst_im = im.copy()
        dst_polys = []
        # Pick 1, 2 or 3 quarter-turns with roughly equal probability.
        rand_degree_ratio = np.random.rand()
        rand_degree_cnt = 1
        if 0.333 < rand_degree_ratio < 0.666:
            rand_degree_cnt = 2
        elif rand_degree_ratio > 0.666:
            rand_degree_cnt = 3
        for i in range(rand_degree_cnt):
            dst_im = np.rot90(dst_im)
        # np.rot90 is counter-clockwise, so the polygon transform uses the
        # corresponding negative angle.
        rot_degree = -90 * rand_degree_cnt
        rot_angle = rot_degree * math.pi / 180.0
        n_poly = text_polys.shape[0]
        # Rotate each vertex about the old image center (cx, cy), then
        # translate to the new image center (ncx, ncy).
        cx, cy = 0.5 * im_w, 0.5 * im_h
        ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
        for i in range(n_poly):
            wordBB = text_polys[i]
            poly = []
            for j in range(4):
                sx, sy = wordBB[j][0], wordBB[j][1]
                dx = math.cos(rot_angle) * (sx - cx)\
                    - math.sin(rot_angle) * (sy - cy) + ncx
                dy = math.sin(rot_angle) * (sx - cx)\
                    + math.cos(rot_angle) * (sy - cy) + ncy
                poly.append([dx, dy])
            dst_polys.append(poly)
        dst_polys = np.array(dst_polys, dtype=np.float32)
        return dst_im, dst_polys
+
+    def polygon_area(self, poly):
+        """
+        compute area of a polygon
+        :param poly:
+        :return:
+        """
+        edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
+                (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
+                (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
+                (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
+        return np.sum(edge) / 2.
+
+    def check_and_validate_polys(self, polys, tags, img_height, img_width):
+        """
+        check so that the text poly is in the same direction,
+        and also filter some invalid polygons
+        :param polys:
+        :param tags:
+        :return:
+        """
+        h, w = img_height, img_width
+        if polys.shape[0] == 0:
+            return polys
+        polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
+        polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
+
+        validated_polys = []
+        validated_tags = []
+        for poly, tag in zip(polys, tags):
+            p_area = self.polygon_area(poly)
+            #invalid poly
+            if abs(p_area) < 1:
+                continue
+            if p_area > 0:
+                #'poly in wrong direction'
+                if not tag:
+                    tag = True  #reversed cases should be ignore
+                poly = poly[(0, 3, 2, 1), :]
+            validated_polys.append(poly)
+            validated_tags.append(tag)
+        return np.array(validated_polys), np.array(validated_tags)
+
+    def draw_img_polys(self, img, polys):
+        if len(img.shape) == 4:
+            img = np.squeeze(img, axis=0)
+        if img.shape[0] == 3:
+            img = img.transpose((1, 2, 0))
+            img[:, :, 2] += 123.68
+            img[:, :, 1] += 116.78
+            img[:, :, 0] += 103.94
+        cv2.imwrite("tmp.jpg", img)
+        img = cv2.imread("tmp.jpg")
+        for box in polys:
+            box = box.astype(np.int32).reshape((-1, 1, 2))
+            cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
+        import random
+        ino = random.randint(0, 100)
+        cv2.imwrite("tmp_%d.jpg" % ino, img)
+        return
+
    def shrink_poly(self, poly, r):
        """
        fit a poly inside the origin poly, maybe bugs here...
        used for generate the score map

        Each vertex is moved inward by R * r[i] along the polygon edges;
        the longer edge pair is shrunk first. The input is modified in
        place and also returned.

        :param poly: the text poly (4x2, mutated in place)
        :param r: r in the paper — per-vertex reference lengths
        :return: the shrinked poly
        """
        # shrink ratio
        R = 0.3
        # find the longer pair
        dist0 = np.linalg.norm(poly[0] - poly[1])
        dist1 = np.linalg.norm(poly[2] - poly[3])
        dist2 = np.linalg.norm(poly[0] - poly[3])
        dist3 = np.linalg.norm(poly[1] - poly[2])
        if dist0 + dist1 > dist2 + dist3:
            # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
            ## p0, p1
            theta = np.arctan2((poly[1][1] - poly[0][1]),
                               (poly[1][0] - poly[0][0]))
            poly[0][0] += R * r[0] * np.cos(theta)
            poly[0][1] += R * r[0] * np.sin(theta)
            poly[1][0] -= R * r[1] * np.cos(theta)
            poly[1][1] -= R * r[1] * np.sin(theta)
            ## p2, p3
            theta = np.arctan2((poly[2][1] - poly[3][1]),
                               (poly[2][0] - poly[3][0]))
            poly[3][0] += R * r[3] * np.cos(theta)
            poly[3][1] += R * r[3] * np.sin(theta)
            poly[2][0] -= R * r[2] * np.cos(theta)
            poly[2][1] -= R * r[2] * np.sin(theta)
            ## p0, p3
            theta = np.arctan2((poly[3][0] - poly[0][0]),
                               (poly[3][1] - poly[0][1]))
            poly[0][0] += R * r[0] * np.sin(theta)
            poly[0][1] += R * r[0] * np.cos(theta)
            poly[3][0] -= R * r[3] * np.sin(theta)
            poly[3][1] -= R * r[3] * np.cos(theta)
            ## p1, p2
            theta = np.arctan2((poly[2][0] - poly[1][0]),
                               (poly[2][1] - poly[1][1]))
            poly[1][0] += R * r[1] * np.sin(theta)
            poly[1][1] += R * r[1] * np.cos(theta)
            poly[2][0] -= R * r[2] * np.sin(theta)
            poly[2][1] -= R * r[2] * np.cos(theta)
        else:
            # shorter horizontal pair: shrink (p0, p3), (p1, p2) first
            ## p0, p3
            theta = np.arctan2((poly[3][0] - poly[0][0]),
                               (poly[3][1] - poly[0][1]))
            poly[0][0] += R * r[0] * np.sin(theta)
            poly[0][1] += R * r[0] * np.cos(theta)
            poly[3][0] -= R * r[3] * np.sin(theta)
            poly[3][1] -= R * r[3] * np.cos(theta)
            ## p1, p2
            theta = np.arctan2((poly[2][0] - poly[1][0]),
                               (poly[2][1] - poly[1][1]))
            poly[1][0] += R * r[1] * np.sin(theta)
            poly[1][1] += R * r[1] * np.cos(theta)
            poly[2][0] -= R * r[2] * np.sin(theta)
            poly[2][1] -= R * r[2] * np.cos(theta)
            ## p0, p1
            theta = np.arctan2((poly[1][1] - poly[0][1]),
                               (poly[1][0] - poly[0][0]))
            poly[0][0] += R * r[0] * np.cos(theta)
            poly[0][1] += R * r[0] * np.sin(theta)
            poly[1][0] -= R * r[1] * np.cos(theta)
            poly[1][1] -= R * r[1] * np.sin(theta)
            ## p2, p3
            theta = np.arctan2((poly[2][1] - poly[3][1]),
                               (poly[2][0] - poly[3][0]))
            poly[3][0] += R * r[3] * np.cos(theta)
            poly[3][1] += R * r[3] * np.sin(theta)
            poly[2][0] -= R * r[2] * np.cos(theta)
            poly[2][1] -= R * r[2] * np.sin(theta)
        return poly
+
    def generate_quad(self, im_size, polys, tags):
        """
        Generate quadrangle ground-truth maps for EAST training.

        :param im_size: (h, w) of the output maps
        :param polys: (N, 4, 2) text quads
        :param tags: (N,) bools, True = ignore region
        :return: (score_map, geo_map, training_mask) —
            score_map: (h, w) uint8, 1 inside shrunk text regions;
            geo_map: (h, w, 9) float32, per-pixel offsets to the 4
                corners plus a short-edge normalization term;
            training_mask: (h, w) uint8, 0 where the region is ignored.
        """
        h, w = im_size
        poly_mask = np.zeros((h, w), dtype=np.uint8)
        score_map = np.zeros((h, w), dtype=np.uint8)
        # (x1, y1, ..., x4, y4, short_edge_norm)
        geo_map = np.zeros((h, w, 9), dtype=np.float32)
        # mask used during training, to ignore some hard areas
        training_mask = np.ones((h, w), dtype=np.uint8)
        for poly_idx, poly_tag in enumerate(zip(polys, tags)):
            poly = poly_tag[0]
            tag = poly_tag[1]

            # r[i]: distance from vertex i to its nearer neighbour vertex,
            # used as the per-vertex shrink reference length.
            r = [None, None, None, None]
            for i in range(4):
                dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
                dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
                r[i] = min(dist1, dist2)
            # score map
            shrinked_poly = self.shrink_poly(
                poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
            cv2.fillPoly(score_map, shrinked_poly, 1)
            cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
            # if the poly is too small, then ignore it during training
            poly_h = min(
                np.linalg.norm(poly[0] - poly[3]),
                np.linalg.norm(poly[1] - poly[2]))
            poly_w = min(
                np.linalg.norm(poly[0] - poly[1]),
                np.linalg.norm(poly[2] - poly[3]))
            if min(poly_h, poly_w) < self.min_text_size:
                cv2.fillPoly(training_mask,
                             poly.astype(np.int32)[np.newaxis, :, :], 0)

            # tagged polys (hard / wrong-direction regions) are ignored too
            if tag:
                cv2.fillPoly(training_mask,
                             poly.astype(np.int32)[np.newaxis, :, :], 0)

            xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
            # geo map.
            y_in_poly = xy_in_poly[:, 0]
            x_in_poly = xy_in_poly[:, 1]
            poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
            poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
            # Channels 2*pno, 2*pno+1 hold the (x, y) offset from each
            # in-poly pixel to corner pno; channel 8 normalizes by the
            # shorter poly side.
            for pno in range(4):
                geo_channel_beg = pno * 2
                geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\
                    x_in_poly - poly[pno, 0]
                geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\
                    y_in_poly - poly[pno, 1]
            geo_map[y_in_poly, x_in_poly, 8] = \
                1.0 / max(min(poly_h, poly_w), 1.0)
        return score_map, geo_map, training_mask
+
+    def crop_area(self,
+                  im,
+                  polys,
+                  tags,
+                  crop_background=False,
+                  max_tries=50):
+        """
+        make random crop from the input image
+        :param im: HxWxC image array
+        :param polys: float array of shape (N, 4, 2), text polygons
+        :param tags: bool array of shape (N,), ignore flags per polygon
+        :param crop_background: if True, look for a crop containing NO text
+        :param max_tries: number of random attempts before giving up
+        :return: (im, polys, tags) of the crop; on failure the inputs are
+                 returned unchanged
+        """
+        h, w, _ = im.shape
+        pad_h = h // 10
+        pad_w = w // 10
+        # h_array/w_array mark rows/columns covered by some text polygon,
+        # so crop edges can be sampled only from text-free coordinates.
+        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+        for poly in polys:
+            poly = np.round(poly, decimals=0).astype(np.int32)
+            minx = np.min(poly[:, 0])
+            maxx = np.max(poly[:, 0])
+            w_array[minx + pad_w:maxx + pad_w] = 1
+            miny = np.min(poly[:, 1])
+            maxy = np.max(poly[:, 1])
+            h_array[miny + pad_h:maxy + pad_h] = 1
+        # ensure the cropped area not across a text
+        h_axis = np.where(h_array == 0)[0]
+        w_axis = np.where(w_array == 0)[0]
+        if len(h_axis) == 0 or len(w_axis) == 0:
+            return im, polys, tags
+
+        for i in range(max_tries):
+            xx = np.random.choice(w_axis, size=2)
+            xmin = np.min(xx) - pad_w
+            xmax = np.max(xx) - pad_w
+            xmin = np.clip(xmin, 0, w - 1)
+            xmax = np.clip(xmax, 0, w - 1)
+            yy = np.random.choice(h_axis, size=2)
+            ymin = np.min(yy) - pad_h
+            ymax = np.max(yy) - pad_h
+            ymin = np.clip(ymin, 0, h - 1)
+            ymax = np.clip(ymax, 0, h - 1)
+            if xmax - xmin < self.min_crop_side_ratio * w or \
+               ymax - ymin < self.min_crop_side_ratio * h:
+                # area too small
+                continue
+            if polys.shape[0] != 0:
+                # a polygon is "selected" only if all 4 corners fall inside
+                # the candidate crop
+                poly_axis_in_area = (polys[:, :, 0] >= xmin)\
+                    & (polys[:, :, 0] <= xmax)\
+                    & (polys[:, :, 1] >= ymin)\
+                    & (polys[:, :, 1] <= ymax)
+                selected_polys = np.where(
+                    np.sum(poly_axis_in_area, axis=1) == 4)[0]
+            else:
+                selected_polys = []
+
+            if len(selected_polys) == 0:
+                # no text in this area
+                if crop_background:
+                    im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+                    polys = []
+                    tags = []
+                    return im, polys, tags
+                else:
+                    continue
+
+            im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+            polys = polys[selected_polys]
+            tags = tags[selected_polys]
+            # shift polygon coordinates into the crop's frame
+            polys[:, :, 0] -= xmin
+            polys[:, :, 1] -= ymin
+            return im, polys, tags
+        return im, polys, tags
+
+    def crop_background_infor(self, im, text_polys, text_tags):
+        """Crop a text-free patch and build all-zero training targets."""
+        im, text_polys, text_tags = self.crop_area(
+            im, text_polys, text_tags, crop_background=True)
+
+        # crop_area returns the inputs unchanged when it cannot find a
+        # text-free region, so non-empty polys means the crop failed.
+        if len(text_polys) > 0:
+            return None
+        # pad and resize image
+        input_size = self.input_size
+        im, ratio = self.preprocess(im)
+        score_map = np.zeros((input_size, input_size), dtype=np.float32)
+        geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
+        training_mask = np.ones((input_size, input_size), dtype=np.float32)
+        return im, score_map, geo_map, training_mask
+
+    def crop_foreground_infor(self, im, text_polys, text_tags):
+        """Crop a text-containing patch and build quad ground truth for it."""
+        im, text_polys, text_tags = self.crop_area(
+            im, text_polys, text_tags, crop_background=False)
+
+        if text_polys.shape[0] == 0:
+            return None
+        #continue for all ignore case
+        if np.sum((text_tags * 1.0)) >= text_tags.size:
+            return None
+        # pad and resize image
+        input_size = self.input_size
+        im, ratio = self.preprocess(im)
+        # rescale polygons to match the resized image
+        text_polys[:, :, 0] *= ratio
+        text_polys[:, :, 1] *= ratio
+        _, _, new_h, new_w = im.shape
+        #         print(im.shape)
+        #         self.draw_img_polys(im, text_polys)
+        score_map, geo_map, training_mask = self.generate_quad(
+            (new_h, new_w), text_polys, text_tags)
+        return im, score_map, geo_map, training_mask
+
+    def __call__(self, data):
+        """
+        Full EAST preprocessing pipeline for one sample: optional rotation,
+        polygon validation, random scaling, then either a background or a
+        foreground crop (chosen by self.background_ratio). Returns the data
+        dict with 'image', 'score_map', 'geo_map', 'training_mask' filled in,
+        or None when the sample should be dropped.
+        """
+        im = data['image']
+        text_polys = data['polys']
+        text_tags = data['ignore_tags']
+        if im is None:
+            return None
+        if text_polys.shape[0] == 0:
+            return None
+
+        #add rotate cases
+        if np.random.rand() < 0.5:
+            im, text_polys = self.rotate_im_poly(im, text_polys)
+        h, w, _ = im.shape
+        text_polys, text_tags = self.check_and_validate_polys(text_polys,
+                                                              text_tags, h, w)
+        if text_polys.shape[0] == 0:
+            return None
+
+        # random scale this image
+        rd_scale = np.random.choice(self.random_scale)
+        im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
+        text_polys *= rd_scale
+        if np.random.rand() < self.background_ratio:
+            outs = self.crop_background_infor(im, text_polys, text_tags)
+        else:
+            outs = self.crop_foreground_infor(im, text_polys, text_tags)
+
+        if outs is None:
+            return None
+        im, score_map, geo_map, training_mask = outs
+        # downsample all targets 4x to match the network output stride
+        score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
+        # HWC -> CHW for the geo map
+        geo_map = np.swapaxes(geo_map, 1, 2)
+        geo_map = np.swapaxes(geo_map, 1, 0)
+        geo_map = geo_map[:, ::4, ::4].astype(np.float32)
+        training_mask = training_mask[np.newaxis, ::4, ::4]
+        training_mask = training_mask.astype(np.float32)
+
+        data['image'] = im[0]
+        data['score_map'] = score_map
+        data['geo_map'] = geo_map
+        data['training_mask'] = training_mask
+        # print(im.shape, score_map.shape, geo_map.shape, training_mask.shape)
+        return data

+ 101 - 0
ocr/ppocr/data/imaug/iaa_augment.py

@@ -0,0 +1,101 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import imgaug
+import imgaug.augmenters as iaa
+
+
+class AugmenterBuilder(object):
+    """Build an imgaug augmenter from a declarative list/dict config."""
+
+    def __init__(self):
+        pass
+
+    def build(self, args, root=True):
+        """
+        Recursively construct an augmenter.
+
+        args may be: None/empty (-> None), a list (at the root, a Sequential
+        of sub-builds; otherwise [name, *positional_args] for an iaa class),
+        or a dict {'type': name, 'args': {kwargs}}. Any list value inside the
+        arguments is converted to a tuple (imgaug's range convention).
+        """
+        if args is None or len(args) == 0:
+            return None
+        elif isinstance(args, list):
+            if root:
+                sequence = [self.build(value, root=False) for value in args]
+                return iaa.Sequential(sequence)
+            else:
+                return getattr(iaa, args[0])(
+                    *[self.to_tuple_if_list(a) for a in args[1:]])
+        elif isinstance(args, dict):
+            cls = getattr(iaa, args['type'])
+            return cls(**{
+                k: self.to_tuple_if_list(v)
+                for k, v in args['args'].items()
+            })
+        else:
+            raise RuntimeError('unknown augmenter arg: ' + str(args))
+
+    def to_tuple_if_list(self, obj):
+        """Return obj as a tuple when it is a list, otherwise unchanged."""
+        if isinstance(obj, list):
+            return tuple(obj)
+        return obj
+
+
+class IaaAugment():
+    """Apply an imgaug pipeline to both the image and its text polygons."""
+
+    def __init__(self, augmenter_args=None, **kwargs):
+        # default pipeline: horizontal flip (p=0.5), rotate +-10 deg,
+        # resize by a factor in [0.5, 3]
+        if augmenter_args is None:
+            augmenter_args = [{
+                'type': 'Fliplr',
+                'args': {
+                    'p': 0.5
+                }
+            }, {
+                'type': 'Affine',
+                'args': {
+                    'rotate': [-10, 10]
+                }
+            }, {
+                'type': 'Resize',
+                'args': {
+                    'size': [0.5, 3]
+                }
+            }]
+        self.augmenter = AugmenterBuilder().build(augmenter_args)
+
+    def __call__(self, data):
+        image = data['image']
+        shape = image.shape
+
+        if self.augmenter:
+            # to_deterministic() freezes the random state so the exact same
+            # transform is applied to the image and to the keypoints below
+            aug = self.augmenter.to_deterministic()
+            data['image'] = aug.augment_image(image)
+            data = self.may_augment_annotation(aug, data, shape)
+        return data
+
+    def may_augment_annotation(self, aug, data, shape):
+        """Transform every polygon in data['polys'] with the same augmenter."""
+        if aug is None:
+            return data
+
+        line_polys = []
+        for poly in data['polys']:
+            new_poly = self.may_augment_poly(aug, shape, poly)
+            line_polys.append(new_poly)
+        data['polys'] = np.array(line_polys)
+        return data
+
+    def may_augment_poly(self, aug, img_shape, poly):
+        """Map one polygon through the augmenter via imgaug keypoints."""
+        keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
+        keypoints = aug.augment_keypoints(
+            [imgaug.KeypointsOnImage(
+                keypoints, shape=img_shape)])[0].keypoints
+        poly = [(p.x, p.y) for p in keypoints]
+        return poly

+ 281 - 0
ocr/ppocr/data/imaug/label_ops.py

@@ -0,0 +1,281 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import string
+
+
+class ClsLabelEncode(object):
+    """Map a classification label string to its index in label_list."""
+
+    def __init__(self, label_list, **kwargs):
+        self.label_list = label_list
+
+    def __call__(self, data):
+        label = data['label']
+        # unknown labels drop the sample (None filters it out upstream)
+        if label not in self.label_list:
+            return None
+        label = self.label_list.index(label)
+        data['label'] = label
+        return data
+
+
+class DetLabelEncode(object):
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        import json
+        label = data['label']
+        label = json.loads(label)
+        nBox = len(label)
+        boxes, txts, txt_tags = [], [], []
+        for bno in range(0, nBox):
+            box = label[bno]['points']
+            txt = label[bno]['transcription']
+            boxes.append(box)
+            txts.append(txt)
+            if txt in ['*', '###']:
+                txt_tags.append(True)
+            else:
+                txt_tags.append(False)
+        boxes = self.expand_points_num(boxes)
+        boxes = np.array(boxes, dtype=np.float32)
+        txt_tags = np.array(txt_tags, dtype=np.bool)
+
+        data['polys'] = boxes
+        data['texts'] = txts
+        data['ignore_tags'] = txt_tags
+        return data
+
+    def order_points_clockwise(self, pts):
+        rect = np.zeros((4, 2), dtype="float32")
+        s = pts.sum(axis=1)
+        rect[0] = pts[np.argmin(s)]
+        rect[2] = pts[np.argmax(s)]
+        diff = np.diff(pts, axis=1)
+        rect[1] = pts[np.argmin(diff)]
+        rect[3] = pts[np.argmax(diff)]
+        return rect
+
+    def expand_points_num(self, boxes):
+        max_points_num = 0
+        for box in boxes:
+            if len(box) > max_points_num:
+                max_points_num = len(box)
+        ex_boxes = []
+        for box in boxes:
+            ex_box = box + [box[-1]] * (max_points_num - len(box))
+            ex_boxes.append(ex_box)
+        return ex_boxes
+
+
+class BaseRecLabelEncode(object):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False):
+        # language codes with a built-in or file-based character set
+        support_character_type = [
+            'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
+            'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
+            'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
+            'mr', 'ne'
+        ]
+        assert character_type in support_character_type, "Only {} are supported now but get {}".format(
+            support_character_type, character_type)
+
+        self.max_text_len = max_text_length
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        if character_type == "en":
+            # lowercase alphanumerics only; encode() lowercases input to match
+            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            dict_character = list(self.character_str)
+        elif character_type == "EN_symbol":
+            # same with ASTER setting (use 94 char).
+            self.character_str = string.printable[:-6]
+            dict_character = list(self.character_str)
+        elif character_type in support_character_type:
+            # all other languages load their alphabet from a dict file,
+            # one character per line, UTF-8 encoded
+            self.character_str = ""
+            assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
+                character_type)
+            with open(character_dict_path, "rb") as fin:
+                lines = fin.readlines()
+                for line in lines:
+                    line = line.decode('utf-8').strip("\n").strip("\r\n")
+                    self.character_str += line
+            if use_space_char:
+                self.character_str += " "
+            dict_character = list(self.character_str)
+        self.character_type = character_type
+        # subclasses insert blank/sos/eos tokens here
+        dict_character = self.add_special_char(dict_character)
+        # char -> index lookup used by encode()
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+
+    def add_special_char(self, dict_character):
+        # hook for subclasses; base class adds nothing
+        return dict_character
+
+    def encode(self, text):
+        """convert text-label into text-index.
+        input:
+            text: text labels of each image. [batch_size]
+
+        output:
+            text: concatenated text index for CTCLoss.
+                    [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
+            length: length of each text. [batch_size]
+        """
+        if len(text) == 0 or len(text) > self.max_text_len:
+            return None
+        if self.character_type == "en":
+            text = text.lower()
+        text_list = []
+        for char in text:
+            # out-of-alphabet characters are silently dropped
+            if char not in self.dict:
+                # logger = get_logger()
+                # logger.warning('{} is not in dict'.format(char))
+                continue
+            text_list.append(self.dict[char])
+        if len(text_list) == 0:
+            return None
+        return text_list
+
+
+class CTCLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(CTCLabelEncode,
+              self).__init__(max_text_length, character_dict_path,
+                             character_type, use_space_char)
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        data['length'] = np.array(len(text))
+        # pad with the CTC blank index (0) up to the fixed max length
+        text = text + [0] * (self.max_text_len - len(text))
+        data['label'] = np.array(text)
+        return data
+
+    def add_special_char(self, dict_character):
+        # CTC reserves index 0 for the blank token
+        dict_character = ['blank'] + dict_character
+        return dict_character
+
+
+class AttnLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(AttnLabelEncode,
+              self).__init__(max_text_length, character_dict_path,
+                             character_type, use_space_char)
+
+    def add_special_char(self, dict_character):
+        # attention decoding brackets the alphabet with sos (index 0) and
+        # eos (last index) tokens
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) >= self.max_text_len:
+            return None
+        data['length'] = np.array(len(text))
+        # layout: [sos] + chars + [eos] + padding (0) to max_text_len
+        text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
+                                                               - len(text) - 2)
+        data['label'] = np.array(text)
+        return data
+
+    def get_ignored_tokens(self):
+        """Return [beg_idx, end_idx], the indices decoders should skip."""
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        """Return the dict index of the sos ("beg") or eos ("end") token."""
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+                          % beg_or_end
+        return idx
+
+
+class SRNLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length=25,
+                 character_dict_path=None,
+                 character_type='en',
+                 use_space_char=False,
+                 **kwargs):
+        super(SRNLabelEncode,
+              self).__init__(max_text_length, character_dict_path,
+                             character_type, use_space_char)
+
+    def add_special_char(self, dict_character):
+        # SRN appends sos/eos AFTER the alphabet (unlike AttnLabelEncode)
+        dict_character = dict_character + [self.beg_str, self.end_str]
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        char_num = len(self.character)
+        if text is None:
+            return None
+        if len(text) > self.max_text_len:
+            return None
+        data['length'] = np.array(len(text))
+        # pad with the eos index (char_num - 1) up to the fixed max length
+        text = text + [char_num - 1] * (self.max_text_len - len(text))
+        data['label'] = np.array(text)
+        return data
+
+    def get_ignored_tokens(self):
+        """Return [beg_idx, end_idx], the indices decoders should skip."""
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        """Return the dict index of the sos ("beg") or eos ("end") token."""
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+                          % beg_or_end
+        return idx

+ 157 - 0
ocr/ppocr/data/imaug/make_border_map.py

@@ -0,0 +1,157 @@
+# -*- coding:utf-8 -*- 
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import cv2
+
+np.seterr(divide='ignore', invalid='ignore')
+import pyclipper
+from shapely.geometry import Polygon
+import sys
+import warnings
+
+warnings.simplefilter("ignore")
+
+__all__ = ['MakeBorderMap']
+
+
+class MakeBorderMap(object):
+    """
+    Build the DBNet border (threshold) map: for each text polygon, a band
+    around its boundary whose value decays with distance to the polygon
+    edges, scaled into [thresh_min, thresh_max].
+    """
+
+    def __init__(self,
+                 shrink_ratio=0.4,
+                 thresh_min=0.3,
+                 thresh_max=0.7,
+                 **kwargs):
+        self.shrink_ratio = shrink_ratio
+        self.thresh_min = thresh_min
+        self.thresh_max = thresh_max
+
+    def __call__(self, data):
+
+        img = data['image']
+        text_polys = data['polys']
+        ignore_tags = data['ignore_tags']
+
+        canvas = np.zeros(img.shape[:2], dtype=np.float32)
+        mask = np.zeros(img.shape[:2], dtype=np.float32)
+
+        for i in range(len(text_polys)):
+            if ignore_tags[i]:
+                continue
+            self.draw_border_map(text_polys[i], canvas, mask=mask)
+        # rescale the raw [0, 1] distance response into the target range
+        canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
+
+        data['threshold_map'] = canvas
+        data['threshold_mask'] = mask
+        return data
+
+    def draw_border_map(self, polygon, canvas, mask):
+        """
+        Render one polygon's border band into canvas (distance response)
+        and mask (valid-region indicator). The polygon is dilated outwards
+        by a distance derived from its area/perimeter and shrink_ratio;
+        inside the dilated region each pixel gets 1 - d/distance, where d
+        is its clipped distance to the nearest original polygon edge.
+        """
+        polygon = np.array(polygon)
+        assert polygon.ndim == 2
+        assert polygon.shape[1] == 2
+
+        polygon_shape = Polygon(polygon)
+        if polygon_shape.area <= 0:
+            return
+        # DBNet dilation offset: A * (1 - r^2) / L
+        distance = polygon_shape.area * (
+            1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
+        subject = [tuple(l) for l in polygon]
+        padding = pyclipper.PyclipperOffset()
+        padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+
+        padded_polygon = np.array(padding.Execute(distance)[0])
+        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
+
+        xmin = padded_polygon[:, 0].min()
+        xmax = padded_polygon[:, 0].max()
+        ymin = padded_polygon[:, 1].min()
+        ymax = padded_polygon[:, 1].max()
+        width = xmax - xmin + 1
+        height = ymax - ymin + 1
+
+        # work in the local frame of the padded bounding box
+        polygon[:, 0] = polygon[:, 0] - xmin
+        polygon[:, 1] = polygon[:, 1] - ymin
+
+        xs = np.broadcast_to(
+            np.linspace(
+                0, width - 1, num=width).reshape(1, width), (height, width))
+        ys = np.broadcast_to(
+            np.linspace(
+                0, height - 1, num=height).reshape(height, 1), (height, width))
+
+        # per-edge distance maps; keep the minimum over all edges
+        distance_map = np.zeros(
+            (polygon.shape[0], height, width), dtype=np.float32)
+        for i in range(polygon.shape[0]):
+            j = (i + 1) % polygon.shape[0]
+            absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
+            distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
+        distance_map = distance_map.min(axis=0)
+
+        # clip the box to the canvas, then merge with fmax so overlapping
+        # polygons keep the strongest response
+        xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
+        xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
+        ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
+        ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
+        canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
+            1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
+                             xmin_valid - xmin:xmax_valid - xmax + width],
+            canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
+
+    def _distance(self, xs, ys, point_1, point_2):
+        '''
+        compute the distance from point to a line
+        ys: coordinates in the first axis
+        xs: coordinates in the second axis
+        point_1, point_2: (x, y), the end of the line
+        '''
+        height, width = xs.shape[:2]
+        square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[
+            1])
+        square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[
+            1])
+        square_distance = np.square(point_1[0] - point_2[0]) + np.square(
+            point_1[1] - point_2[1])
+
+        # law of cosines on the triangle (pixel, point_1, point_2)
+        cosin = (square_distance - square_distance_1 - square_distance_2) / (
+            2 * np.sqrt(square_distance_1 * square_distance_2))
+        square_sin = 1 - np.square(cosin)
+        square_sin = np.nan_to_num(square_sin)
+        # perpendicular distance to the (infinite) segment line
+        result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
+                         square_distance)
+
+        # beyond the segment ends, fall back to distance to the nearer endpoint
+        result[cosin <
+               0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin
+                                                                           < 0]
+        # self.extend_line(point_1, point_2, result)
+        return result
+
+    def extend_line(self, point_1, point_2, result, shrink_ratio):
+        """Draw the segment extended beyond both endpoints into result
+        (currently unused; see the commented call in _distance)."""
+        ex_point_1 = (int(
+            round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
+                      int(
+                          round(point_1[1] + (point_1[1] - point_2[1]) * (
+                              1 + shrink_ratio))))
+        cv2.line(
+            result,
+            tuple(ex_point_1),
+            tuple(point_1),
+            4096.0,
+            1,
+            lineType=cv2.LINE_AA,
+            shift=0)
+        ex_point_2 = (int(
+            round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
+                      int(
+                          round(point_2[1] + (point_2[1] - point_1[1]) * (
+                              1 + shrink_ratio))))
+        cv2.line(
+            result,
+            tuple(ex_point_2),
+            tuple(point_2),
+            4096.0,
+            1,
+            lineType=cv2.LINE_AA,
+            shift=0)
+        return ex_point_1, ex_point_2

+ 107 - 0
ocr/ppocr/data/imaug/make_shrink_map.py

@@ -0,0 +1,107 @@
+# -*- coding:utf-8 -*- 
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import cv2
+from shapely.geometry import Polygon
+import pyclipper
+
+__all__ = ['MakeShrinkMap']
+
+
+class MakeShrinkMap(object):
+    r'''
+    Making binary mask from detection data with ICDAR format.
+    Typically following the process of class `MakeICDARData`.
+    '''
+
+    def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
+        self.min_text_size = min_text_size
+        self.shrink_ratio = shrink_ratio
+
+    def __call__(self, data):
+        image = data['image']
+        text_polys = data['polys']
+        ignore_tags = data['ignore_tags']
+
+        h, w = image.shape[:2]
+        text_polys, ignore_tags = self.validate_polygons(text_polys,
+                                                         ignore_tags, h, w)
+        gt = np.zeros((h, w), dtype=np.float32)
+        mask = np.ones((h, w), dtype=np.float32)
+        for i in range(len(text_polys)):
+            polygon = text_polys[i]
+            height = max(polygon[:, 1]) - min(polygon[:, 1])
+            width = max(polygon[:, 0]) - min(polygon[:, 0])
+            if ignore_tags[i] or min(height, width) < self.min_text_size:
+                cv2.fillPoly(mask,
+                             polygon.astype(np.int32)[np.newaxis, :, :], 0)
+                ignore_tags[i] = True
+            else:
+                polygon_shape = Polygon(polygon)
+                subject = [tuple(l) for l in polygon]
+                padding = pyclipper.PyclipperOffset()
+                padding.AddPath(subject, pyclipper.JT_ROUND,
+                                pyclipper.ET_CLOSEDPOLYGON)
+                shrinked = []
+
+                # Increase the shrink ratio every time we get multiple polygon returned back 
+                possible_ratios = np.arange(self.shrink_ratio, 1,
+                                            self.shrink_ratio)
+                np.append(possible_ratios, 1)
+                # print(possible_ratios)
+                for ratio in possible_ratios:
+                    # print(f"Change shrink ratio to {ratio}")
+                    distance = polygon_shape.area * (
+                        1 - np.power(ratio, 2)) / polygon_shape.length
+                    shrinked = padding.Execute(-distance)
+                    if len(shrinked) == 1:
+                        break
+
+                if shrinked == []:
+                    cv2.fillPoly(mask,
+                                 polygon.astype(np.int32)[np.newaxis, :, :], 0)
+                    ignore_tags[i] = True
+                    continue
+
+                for each_shirnk in shrinked:
+                    shirnk = np.array(each_shirnk).reshape(-1, 2)
+                    cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
+                # cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)
+
+        data['shrink_map'] = gt
+        data['shrink_mask'] = mask
+        return data
+
+    def validate_polygons(self, polygons, ignore_tags, h, w):
+        '''
+        polygons (numpy.array, required): of shape (num_instances, num_points, 2)
+        '''
+        if len(polygons) == 0:
+            return polygons, ignore_tags
+        assert len(polygons) == len(ignore_tags)
+        for polygon in polygons:
+            polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
+            polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
+
+        for i in range(len(polygons)):
+            area = self.polygon_area(polygons[i])
+            if abs(area) < 1:
+                ignore_tags[i] = True
+            if area > 0:
+                polygons[i] = polygons[i][::-1, :]
+        return polygons, ignore_tags
+
+    def polygon_area(self, polygon):
+        # return cv2.contourArea(polygon.astype(np.float32))
+        edge = 0
+        for i in range(polygon.shape[0]):
+            next_index = (i + 1) % polygon.shape[0]
+            edge += (polygon[next_index, 0] - polygon[i, 0]) * (
+                polygon[next_index, 1] - polygon[i, 1])
+
+        return edge / 2.

+ 225 - 0
ocr/ppocr/data/imaug/operators.py

@@ -0,0 +1,225 @@
+"""
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import six
+import cv2
+import numpy as np
+
+
class DecodeImage(object):
    """Decode an encoded image byte string into an ndarray (BGR/RGB/CHW)."""

    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
        # img_mode: 'RGB', 'BGR' (default decode order) or 'GRAY'.
        # channel_first: transpose HWC -> CHW after decoding.
        self.img_mode = img_mode
        self.channel_first = channel_first

    def __call__(self, data):
        raw = data['image']
        expected_type = str if six.PY2 else bytes
        assert type(raw) is expected_type and len(
            raw) > 0, "invalid input 'img' in DecodeImage"
        buf = np.frombuffer(raw, dtype='uint8')
        img = cv2.imdecode(buf, 1)
        if img is None:
            return None
        if self.img_mode == 'GRAY':
            # NOTE(review): imdecode(..., 1) yields a 3-channel image, so
            # COLOR_GRAY2BGR here would raise on that input — confirm the
            # intended decode flag for GRAY mode.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif self.img_mode == 'RGB':
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
            img = img[:, :, ::-1]  # BGR -> RGB

        if self.channel_first:
            img = img.transpose((2, 0, 1))

        data['image'] = img
        return data
+
+
class NormalizeImage(object):
    """Normalize an image: x -> (x * scale - mean) / std."""

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            # NOTE(review): the scale string from the config is eval'd;
            # do not feed untrusted configuration through this path.
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        # ImageNet statistics are the defaults.
        if mean is None:
            mean = [0.485, 0.456, 0.406]
        if std is None:
            std = [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, data):
        img = data['image']
        from PIL import Image  # lazy: only needed for PIL inputs
        if isinstance(img, Image.Image):
            img = np.array(img)

        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data
+
+
class ToCHWImage(object):
    """Transpose data['image'] from HWC to CHW layout."""

    def __init__(self, **kwargs):
        # Stateless; **kwargs absorbs unused config keys.
        pass

    def __call__(self, data):
        pic = data['image']
        from PIL import Image  # lazy: only needed for PIL inputs
        if isinstance(pic, Image.Image):
            pic = np.array(pic)
        data['image'] = pic.transpose((2, 0, 1))
        return data
+
+
class KeepKeys(object):
    """Project the data dict onto an ordered list of selected keys."""

    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        return [data[key] for key in self.keep_keys]
+
+
class DetResizeForTest(object):
    """Resize an image for text-detection inference.

    Three modes, selected by the constructor kwargs:
      * resize_type 0 (default / 'limit_side_len'): scale so the min or max
        side respects ``limit_side_len``, then snap both sides to a
        multiple of 32 (network stride).
      * resize_type 1 ('image_shape'): resize to a fixed (h, w).
      * resize_type 2 ('resize_long'): scale the longer side to
        ``resize_long``, then snap both sides up to a multiple of 128.
    """

    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        if 'image_shape' in kwargs:
            self.image_shape = kwargs['image_shape']
            self.resize_type = 1
        elif 'limit_side_len' in kwargs:
            self.limit_side_len = kwargs['limit_side_len']
            self.limit_type = kwargs.get('limit_type', 'min')
        elif 'resize_long' in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get('resize_long', 960)
        else:
            self.limit_side_len = 736
            self.limit_type = 'min'

    def __call__(self, data):
        """Resize data['image']; record [src_h, src_w, ratio_h, ratio_w]
        in data['shape'] so predictions can be mapped back."""
        img = data['image']
        src_h, src_w, _ = img.shape

        if self.resize_type == 0:
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_type1(self, img):
        """Resize to the fixed self.image_shape = (h, w)."""
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """
        resize image to a size multiple of 32 which is required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, _ = img.shape

        if self.limit_type == 'max':
            # Shrink so the longer side does not exceed limit_side_len.
            if max(h, w) > limit_side_len:
                ratio = float(limit_side_len) / max(h, w)
            else:
                ratio = 1.
        else:
            # Grow so the shorter side reaches limit_side_len.
            if min(h, w) < limit_side_len:
                ratio = float(limit_side_len) / min(h, w)
            else:
                ratio = 1.
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        # Snap to multiples of 32, but never collapse to 0.
        # BUGFIX: the original could round a side down to 0 and then either
        # returned (None, (None, None)) — propagating None as the image —
        # or hit a bare `except:` that called sys.exit(0), killing the whole
        # process from inside a data transform. Clamping to >= 32 removes
        # both failure paths.
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        """Scale the longer side to self.resize_long, then round both sides
        up to a multiple of 128."""
        h, w, _ = img.shape

        ratio = float(self.resize_long) / max(h, w)
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return img, [ratio_h, ratio_w]

+ 140 - 0
ocr/ppocr/data/imaug/randaugment.py

@@ -0,0 +1,140 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from PIL import Image, ImageEnhance, ImageOps
+import numpy as np
+import random
+import six
+
+
class RawRandAugment(object):
    """RandAugment on PIL images: applies `num_layers` randomly chosen ops,
    each at a fixed strength derived from `magnitude` (0..max_level scale).
    See RandAugment below for the ndarray-based wrapper."""

    def __init__(self,
                 num_layers=2,
                 magnitude=5,
                 fillcolor=(128, 128, 128),
                 **kwargs):
        # num_layers: number of ops applied per image.
        # magnitude: global strength knob on a 0..max_level scale.
        # fillcolor: pad color for the affine transforms.
        self.num_layers = num_layers
        self.magnitude = magnitude
        self.max_level = 10

        # Fraction of each op's maximum strength to use.
        abso_level = self.magnitude / self.max_level
        # Per-op strength (units differ per op: shear factor, pixels,
        # degrees, enhancement delta, ...).
        self.level_map = {
            "shearX": 0.3 * abso_level,
            "shearY": 0.3 * abso_level,
            "translateX": 150.0 / 331 * abso_level,
            "translateY": 150.0 / 331 * abso_level,
            "rotate": 30 * abso_level,
            "color": 0.9 * abso_level,
            "posterize": int(4.0 * abso_level),
            "solarize": 256.0 * abso_level,
            "contrast": 0.9 * abso_level,
            "sharpness": 0.9 * abso_level,
            "brightness": 0.9 * abso_level,
            "autocontrast": 0,
            "equalize": 0,
            "invert": 0
        }

        # from https://stackoverflow.com/questions/5252170/
        # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Rotate on an RGBA canvas, then composite over mid-gray so the
            # corners exposed by the rotation are filled instead of black.
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot,
                                   Image.new("RGBA", rot.size, (128, ) * 4),
                                   rot).convert(img.mode)

        rnd_ch_op = random.choice  # used to pick a random sign below

        # Op name -> callable(img, level); levels come from self.level_map.
        self.func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
                1 + magnitude * rnd_ch_op([-1, 1])),
            "posterize": lambda img, magnitude:
            ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude:
            ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude:
            ImageEnhance.Contrast(img).enhance(
                1 + magnitude * rnd_ch_op([-1, 1])),
            "sharpness": lambda img, magnitude:
            ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * rnd_ch_op([-1, 1])),
            "brightness": lambda img, magnitude:
            ImageEnhance.Brightness(img).enhance(
                1 + magnitude * rnd_ch_op([-1, 1])),
            "autocontrast": lambda img, magnitude:
            ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

    def __call__(self, img):
        """Apply num_layers independently sampled ops to a PIL image."""
        avaiable_op_names = list(self.level_map.keys())
        for layer_num in range(self.num_layers):
            op_name = np.random.choice(avaiable_op_names)
            img = self.func[op_name](img, self.level_map[op_name])
        return img
+
+
class RandAugment(RawRandAugment):
    """ RandAugment wrapper to auto fit different img types """

    def __init__(self, *args, **kwargs):
        if six.PY2:
            super(RandAugment, self).__init__(*args, **kwargs)
        else:
            super().__init__(*args, **kwargs)

    def __call__(self, data):
        # Convert ndarray input to a PIL image for the parent implementation.
        img = data['image']
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.ascontiguousarray(img))

        if six.PY2:
            img = super(RandAugment, self).__call__(img)
        else:
            img = super().__call__(img)

        # And back to an ndarray for the rest of the pipeline.
        if isinstance(img, Image.Image):
            img = np.asarray(img)
        data['image'] = img
        return data

+ 210 - 0
ocr/ppocr/data/imaug/random_crop_data.py

@@ -0,0 +1,210 @@
+# -*- coding:utf-8 -*- 
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import cv2
+import random
+
+
def is_poly_in_rect(poly, x, y, w, h):
    """Return True iff every vertex of `poly` lies inside [x, x+w] x [y, y+h]."""
    pts = np.array(poly)
    x_inside = pts[:, 0].min() >= x and pts[:, 0].max() <= x + w
    y_inside = pts[:, 1].min() >= y and pts[:, 1].max() <= y + h
    return bool(x_inside and y_inside)
+
+
def is_poly_outside_rect(poly, x, y, w, h):
    """Return True iff `poly` lies entirely outside [x, x+w] x [y, y+h]
    along at least one axis."""
    pts = np.array(poly)
    outside_x = pts[:, 0].max() < x or pts[:, 0].min() > x + w
    outside_y = pts[:, 1].max() < y or pts[:, 1].min() > y + h
    return bool(outside_x or outside_y)
+
+
def split_regions(axis):
    """Split a sorted index array into runs of consecutive integers.

    NOTE(review): the trailing run (from the last break to the end of the
    array) is never appended, so it can never be chosen by the callers —
    looks unintended but is kept for behavioral compatibility.
    """
    regions = []
    start = 0
    for pos in range(1, axis.shape[0]):
        if axis[pos] != axis[pos - 1] + 1:
            regions.append(axis[start:pos])
            start = pos
    return regions
+
+
def random_select(axis, max_size):
    """Draw two random values from `axis` and return them as an ordered
    (min, max) pair, each clipped into [0, max_size - 1]."""
    picks = np.random.choice(axis, size=2)
    lo = np.clip(np.min(picks), 0, max_size - 1)
    hi = np.clip(np.max(picks), 0, max_size - 1)
    return lo, hi
+
+
def region_wise_random_select(regions, max_size):
    """Pick one value from each of two randomly chosen regions and return
    them as an ordered (min, max) pair.

    `max_size` is unused; kept for signature parity with random_select.
    """
    chosen = list(np.random.choice(len(regions), 2))
    values = [int(np.random.choice(regions[idx], size=1)) for idx in chosen]
    return min(values), max(values)
+
+
def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
    """Randomly pick a crop window that does not cut through any text.

    Returns (x, y, w, h). Falls back to the full image when there is no
    text-free row/column or when no valid window (large enough, containing
    at least one polygon) is found within `max_tries` attempts.
    """
    h, w, _ = im.shape
    h_array = np.zeros(h, dtype=np.int32)
    w_array = np.zeros(w, dtype=np.int32)

    # Mark the rows/columns covered by any text polygon.
    for pts in text_polys:
        pts = np.round(pts, decimals=0).astype(np.int32)
        w_array[np.min(pts[:, 0]):np.max(pts[:, 0])] = 1
        h_array[np.min(pts[:, 1]):np.max(pts[:, 1])] = 1

    # Candidate cut positions: rows/columns free of text.
    h_axis = np.where(h_array == 0)[0]
    w_axis = np.where(w_array == 0)[0]
    if len(h_axis) == 0 or len(w_axis) == 0:
        return 0, 0, w, h

    h_regions = split_regions(h_axis)
    w_regions = split_regions(w_axis)

    for _ in range(max_tries):
        if len(w_regions) > 1:
            xmin, xmax = region_wise_random_select(w_regions, w)
        else:
            xmin, xmax = random_select(w_axis, w)
        if len(h_regions) > 1:
            ymin, ymax = region_wise_random_select(h_regions, h)
        else:
            ymin, ymax = random_select(h_axis, h)

        if (xmax - xmin < min_crop_side_ratio * w or
                ymax - ymin < min_crop_side_ratio * h):
            # Candidate window too small; try again.
            continue

        # Accept the window only if it still contains some text.
        contains_text = any(
            not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
                                     ymax - ymin) for poly in text_polys)
        if contains_text:
            return xmin, ymin, xmax - xmin, ymax - ymin

    return 0, 0, w, h
+
+
class EastRandomCropData(object):
    """EAST-style text-aware random crop: cut a window, scale it into
    `size`, optionally padding to preserve the aspect ratio, and keep only
    the polygons that still intersect the window."""

    def __init__(self,
                 size=(640, 640),
                 max_tries=10,
                 min_crop_side_ratio=0.1,
                 keep_ratio=True,
                 **kwargs):
        self.size = size
        self.max_tries = max_tries
        self.min_crop_side_ratio = min_crop_side_ratio
        self.keep_ratio = keep_ratio

    def __call__(self, data):
        img = data['image']
        polys = data['polys']
        ignore_tags = data['ignore_tags']
        texts = data['texts']

        # Only non-ignored polygons constrain the crop window.
        care_polys = [p for p, ign in zip(polys, ignore_tags) if not ign]
        crop_x, crop_y, crop_w, crop_h = crop_area(
            img, care_polys, self.min_crop_side_ratio, self.max_tries)

        # Scale the crop into the target size.
        scale = min(self.size[0] / crop_w, self.size[1] / crop_h)
        out_h = int(crop_h * scale)
        out_w = int(crop_w * scale)
        cropped = img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w]
        if self.keep_ratio:
            # Pad with zeros so the aspect ratio is preserved.
            canvas = np.zeros((self.size[1], self.size[0], img.shape[2]),
                              img.dtype)
            canvas[:out_h, :out_w] = cv2.resize(cropped, (out_w, out_h))
            img = canvas
        else:
            img = cv2.resize(cropped, tuple(self.size))

        # Transform the boxes into crop coordinates and drop the ones that
        # fell completely outside the window.
        kept_polys, kept_tags, kept_texts = [], [], []
        for poly, text, tag in zip(polys, texts, ignore_tags):
            poly = ((poly - (crop_x, crop_y)) * scale).tolist()
            if not is_poly_outside_rect(poly, 0, 0, out_w, out_h):
                kept_polys.append(poly)
                kept_tags.append(tag)
                kept_texts.append(text)
        data['image'] = img
        data['polys'] = np.array(kept_polys)
        data['ignore_tags'] = kept_tags
        data['texts'] = kept_texts
        return data
+
+
class PSERandomCrop(object):
    """Random crop for PSE-style targets.

    data['imgs'] is a list of aligned arrays (image plus label maps); the
    same (th, tw) window is cut from every array in the list.
    """

    def __init__(self, size, **kwargs):
        # size: (target_h, target_w) of the crop window.
        self.size = size

    def __call__(self, data):
        imgs = data['imgs']

        h, w = imgs[0].shape[0:2]
        th, tw = self.size
        if w == tw and h == th:
            # Nothing to crop. BUGFIX: the original returned the bare list
            # here while every other path returns the data dict, giving
            # callers an inconsistent return type; return the dict instead.
            data['imgs'] = imgs
            return data

        # If the threshold label map (imgs[2]) contains text, bias the crop
        # (with probability 5/8) towards windows containing text.
        if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
            # Top-left corner of the text region, shifted so a crop anchored
            # there can still cover it.
            tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
            tl[tl < 0] = 0
            # Bottom-right corner, likewise.
            br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
            br[br < 0] = 0
            # Ensure a crop anchored at br still fits inside the image.
            br[0] = min(br[0], h - th)
            br[1] = min(br[1], w - tw)

            for _ in range(50000):
                i = random.randint(tl[0], br[0])
                j = random.randint(tl[1], br[1])
                # Require the shrink label map (imgs[1]) to contain text.
                if imgs[1][i:i + th, j:j + tw].sum() <= 0:
                    continue
                else:
                    break
        else:
            i = random.randint(0, h - th)
            j = random.randint(0, w - tw)

        # Cut the same window from every array.
        for idx in range(len(imgs)):
            if len(imgs[idx].shape) == 3:
                imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
            else:
                imgs[idx] = imgs[idx][i:i + th, j:j + tw]
        data['imgs'] = imgs
        return data

+ 435 - 0
ocr/ppocr/data/imaug/rec_img_aug.py

@@ -0,0 +1,435 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import cv2
+import numpy as np
+import random
+
+from .text_image_aug import tia_perspective, tia_stretch, tia_distort
+
+
class RecAug(object):
    """Recognition-training augmentation: run warp() over the image."""

    def __init__(self, use_tia=True, aug_prob=0.4, **kwargs):
        # use_tia: enable the TIA distort/stretch/perspective transforms.
        # aug_prob: per-transform firing probability inside warp().
        self.use_tia = use_tia
        self.aug_prob = aug_prob

    def __call__(self, data):
        data['image'] = warp(data['image'], 10, self.use_tia, self.aug_prob)
        return data
+
+
class ClsResizeImg(object):
    """Resize + normalize an image for the text-direction classifier."""

    def __init__(self, image_shape, **kwargs):
        self.image_shape = image_shape

    def __call__(self, data):
        data['image'] = resize_norm_img(data['image'], self.image_shape)
        return data
+
+
class RecResizeImg(object):
    """Resize + normalize an image for the recognizer.

    In Chinese inference mode the width is padded dynamically via
    resize_norm_img_chinese; otherwise the fixed target shape is used.
    """

    def __init__(self,
                 image_shape,
                 infer_mode=False,
                 character_type='ch',
                 **kwargs):
        self.image_shape = image_shape
        self.infer_mode = infer_mode
        self.character_type = character_type

    def __call__(self, data):
        use_chinese = self.infer_mode and self.character_type == "ch"
        resizer = resize_norm_img_chinese if use_chinese else resize_norm_img
        data['image'] = resizer(data['image'], self.image_shape)
        return data
+
+
class SRNRecResizeImg(object):
    """Resize for SRN and attach the auxiliary positional/attention inputs."""

    def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
        self.image_shape = image_shape
        self.num_heads = num_heads
        self.max_text_length = max_text_length

    def __call__(self, data):
        data['image'] = resize_norm_img_srn(data['image'], self.image_shape)
        extras = srn_other_inputs(self.image_shape, self.num_heads,
                                  self.max_text_length)
        (data['encoder_word_pos'], data['gsrm_word_pos'],
         data['gsrm_slf_attn_bias1'], data['gsrm_slf_attn_bias2']) = extras
        return data
+
+
def resize_norm_img(img, image_shape):
    """Resize to height imgH keeping aspect ratio (width capped at imgW),
    scale values to [-1, 1], transpose to CHW and right-pad with zeros to
    exactly (imgC, imgH, imgW)."""
    imgC, imgH, imgW = image_shape
    h, w = img.shape[0], img.shape[1]
    ratio = w / float(h)
    if math.ceil(imgH * ratio) > imgW:
        resized_w = imgW
    else:
        resized_w = int(math.ceil(imgH * ratio))
    resized = cv2.resize(img, (resized_w, imgH)).astype('float32')
    if imgC == 1:
        resized = resized / 255
        resized = resized[np.newaxis, :]
    else:
        resized = resized.transpose((2, 0, 1)) / 255
    # Map [0, 1] -> [-1, 1].
    resized -= 0.5
    resized /= 0.5
    padded = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padded[:, :, 0:resized_w] = resized
    return padded
+
+
def resize_norm_img_chinese(img, image_shape):
    """Like resize_norm_img, but widens the target width to fit the image's
    aspect ratio before padding.

    NOTE(review): the width is recomputed as 32 * max_wh_ratio, hard-coding
    the usual recognition height of 32 instead of using imgH — confirm
    whether image_shape ever uses another height.
    """
    imgC, imgH, imgW = image_shape
    # todo: change to 0 and modified image shape
    max_wh_ratio = imgW * 1.0 / imgH
    h, w = img.shape[0], img.shape[1]
    ratio = w * 1.0 / h
    max_wh_ratio = max(max_wh_ratio, ratio)
    imgW = int(32 * max_wh_ratio)
    if math.ceil(imgH * ratio) > imgW:
        resized_w = imgW
    else:
        resized_w = int(math.ceil(imgH * ratio))
    resized = cv2.resize(img, (resized_w, imgH)).astype('float32')
    if imgC == 1:
        resized = resized / 255
        resized = resized[np.newaxis, :]
    else:
        resized = resized.transpose((2, 0, 1)) / 255
    # Map [0, 1] -> [-1, 1].
    resized -= 0.5
    resized /= 0.5
    padded = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padded[:, :, 0:resized_w] = resized
    return padded
+
+
def resize_norm_img_srn(img, image_shape):
    """Build an SRN input: resize with the width snapped to 1x/2x/3x the
    target height (capped at imgW), convert to grayscale, left-align on a
    black canvas and return float32 of shape (1, imgH, imgW)."""
    imgC, imgH, imgW = image_shape

    canvas = np.zeros((imgH, imgW))
    h = img.shape[0]
    w = img.shape[1]

    if w <= h:
        resized = cv2.resize(img, (imgH, imgH))
    elif w <= h * 2:
        resized = cv2.resize(img, (imgH * 2, imgH))
    elif w <= h * 3:
        resized = cv2.resize(img, (imgH * 3, imgH))
    else:
        resized = cv2.resize(img, (imgW, imgH))

    gray = cv2.cvtColor(np.asarray(resized), cv2.COLOR_BGR2GRAY)
    canvas[:, 0:gray.shape[1]] = gray
    canvas = canvas[:, :, np.newaxis]

    row, col, _ = canvas.shape
    return np.reshape(canvas, (1, row, col)).astype(np.float32)
+
+
def srn_other_inputs(image_shape, num_heads, max_text_length):
    """Build SRN auxiliary inputs: encoder/GSRM position ids plus the two
    triangular self-attention bias masks (-1e9 at masked positions),
    tiled over `num_heads`."""
    imgC, imgH, imgW = image_shape
    feature_dim = int((imgH / 8) * (imgW / 8))

    encoder_word_pos = np.arange(feature_dim, dtype='int64').reshape(
        (feature_dim, 1))
    gsrm_word_pos = np.arange(max_text_length, dtype='int64').reshape(
        (max_text_length, 1))

    ones = np.ones((1, max_text_length, max_text_length))
    # Strict upper triangle masked (future positions).
    bias1 = np.triu(ones, 1).reshape([1, max_text_length, max_text_length])
    bias1 = np.tile(bias1, [num_heads, 1, 1]) * [-1e9]
    # Strict lower triangle masked (past positions).
    bias2 = np.tril(ones, -1).reshape([1, max_text_length, max_text_length])
    bias2 = np.tile(bias2, [num_heads, 1, 1]) * [-1e9]

    return [encoder_word_pos, gsrm_word_pos, bias1, bias2]
+
+
def flag():
    """Return a random sign, +1 or -1 (very slightly biased towards -1)."""
    return -1 if random.random() <= 0.5000001 else 1
+
+
def cvtColor(img):
    """Randomly perturb brightness via the HSV value channel (up to 0.1%)."""
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    delta = 0.001 * random.random() * flag()
    hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
+
+
def blur(img):
    """Gaussian-blur images larger than 10x10; smaller ones pass through."""
    h, w, _ = img.shape
    if h <= 10 or w <= 10:
        return img
    return cv2.GaussianBlur(img, (5, 5), 1)
+
+
def jitter(img):
    """Diagonal jitter: shift the image content along the diagonal by up to
    ~1% of the shorter side (in place). Images 10px or smaller on either
    side pass through unchanged."""
    w, h, _ = img.shape  # NOTE: names follow the original (axes swapped)
    if h <= 10 or w <= 10:
        return img
    shift = int(random.random() * min(w, h) * 0.01)
    snapshot = img.copy()
    for off in range(shift):
        img[off:, off:, :] = snapshot[:w - off, :h - off, :]
    return img
+
+
def add_gasuss_noise(image, mean=0, var=0.1):
    """Add Gaussian noise (half-strength) and clip back to the uint8 range."""
    noise = np.random.normal(mean, var**0.5, image.shape)
    noisy = np.clip(image + 0.5 * noise, 0, 255)
    return np.uint8(noisy)
+
+
def get_crop(image):
    """Randomly trim 1-8 rows (never the whole image) from either the top
    or the bottom edge, on a copy of the input."""
    h = image.shape[0]
    cut = min(random.randint(1, 8), h - 1)
    cropped = image.copy()
    if random.randint(0, 1):
        return cropped[cut:h, :, :]
    return cropped[0:h - cut, :, :]
+
+
class Config:
    """Holds the randomly drawn augmentation settings used by warp()."""

    def __init__(self, use_tia):
        # Initial random rotation/shear draws; make() re-draws most of them
        # per image.
        self.anglex = random.random() * 30
        self.angley = random.random() * 15
        self.anglez = random.random() * 10
        self.fov = 42
        self.r = 0
        self.shearx = random.random() * 0.3
        self.sheary = random.random() * 0.05
        self.borderMode = cv2.BORDER_REPLICATE
        self.use_tia = use_tia

    def make(self, w, h, ang):
        """Re-draw the per-image parameters for an image of size (w, h)."""
        self.anglex = random.random() * 5 * flag()
        self.angley = random.random() * 5 * flag()
        self.anglez = -1 * random.random() * int(ang) * flag()
        self.fov = 42
        self.r = 0
        self.shearx = 0
        self.sheary = 0
        self.borderMode = cv2.BORDER_REPLICATE
        self.w = w
        self.h = h

        # The three TIA transforms are switched on/off together.
        self.perspective = self.stretch = self.distort = self.use_tia

        # Photometric/other augmentations; affine stays disabled.
        self.crop = True
        self.affine = False
        self.reverse = True
        self.noise = True
        self.jitter = True
        self.blur = True
        self.color = True
+
+
def rad(x):
    """Convert degrees to radians."""
    return x * np.pi / 180
+
+
def get_warpR(config):
    """
    get_warpR

    Build a 3x3 perspective-warp matrix that rotates the image plane by
    (anglex, angley, anglez) about its center under a pinhole camera with
    field of view `fov`, then rescales/translates so the projected quad
    stays inside the frame.

    Returns (matrix, (dy, dx), ratio, dst): the combined 3x3 transform, the
    shift that was applied, the uniform rescale factor, and the four
    projected corner points.
    """
    anglex, angley, anglez, fov, w, h, r = \
        config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r
    if w > 69 and w < 112:
        # Boost the x-rotation for mid-width crops.
        anglex = anglex * 1.5

    # Camera distance chosen so the image diagonal spans the field of view.
    z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2))
    # Homogeneous coordinate transformation matrix
    rx = np.array([[1, 0, 0, 0],
                   [0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [
                       0,
                       -np.sin(rad(anglex)),
                       np.cos(rad(anglex)),
                       0,
                   ], [0, 0, 0, 1]], np.float32)
    ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
                   [0, 1, 0, 0], [
                       -np.sin(rad(angley)),
                       0,
                       np.cos(rad(angley)),
                       0,
                   ], [0, 0, 0, 1]], np.float32)
    rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
                   [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
                   [0, 0, 1, 0], [0, 0, 0, 1]], np.float32)
    r = rx.dot(ry).dot(rz)
    # generate 4 points: image corners expressed relative to the center
    pcenter = np.array([h / 2, w / 2, 0, 0], np.float32)
    p1 = np.array([0, 0, 0, 0], np.float32) - pcenter
    p2 = np.array([w, 0, 0, 0], np.float32) - pcenter
    p3 = np.array([0, h, 0, 0], np.float32) - pcenter
    p4 = np.array([w, h, 0, 0], np.float32) - pcenter
    dst1 = r.dot(p1)
    dst2 = r.dot(p2)
    dst3 = r.dot(p3)
    dst4 = r.dot(p4)
    list_dst = np.array([dst1, dst2, dst3, dst4])
    org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
    dst = np.zeros((4, 2), np.float32)
    # Project onto the image plane (perspective divide by z - depth).
    dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0]
    dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1]

    warpR = cv2.getPerspectiveTransform(org, dst)

    dst1, dst2, dst3, dst4 = dst
    # Axis-aligned bounding box of the projected quad.
    r1 = int(min(dst1[1], dst2[1]))
    r2 = int(max(dst3[1], dst4[1]))
    c1 = int(min(dst1[0], dst3[0]))
    c2 = int(max(dst2[0], dst4[0]))

    try:
        # Rescale and shift so the projected quad fits the original frame.
        ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1))

        dx = -c1
        dy = -r1
        T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]])
        ret = T1.dot(warpR)
    except:
        # Degenerate projection (zero-size bounding box): fall back to
        # the identity transform.
        ratio = 1.0
        T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
        ret = T1
    return ret, (-r1, -c1), ratio, dst
+
+
def get_warpAffine(config):
    """Build the 2x3 affine matrix rotating by config.anglez degrees
    about the origin."""
    theta = config.anglez * np.pi / 180  # inline degree->radian conversion
    return np.array([[np.cos(theta), np.sin(theta), 0],
                     [-np.sin(theta), np.cos(theta), 0]], np.float32)
+
+
def warp(img, ang, use_tia=True, prob=0.4):
    """Apply the configured chain of augmentations to `img`.

    Each step fires independently with probability `prob` (jitter always
    runs); TIA distort/stretch and crop additionally require the image to
    be at least 20x20.
    """
    h, w, _ = img.shape
    config = Config(use_tia=use_tia)
    config.make(w, h, ang)
    out = img

    if config.distort:
        ih, iw = img.shape[0:2]
        if random.random() <= prob and ih >= 20 and iw >= 20:
            out = tia_distort(out, random.randint(3, 6))

    if config.stretch:
        ih, iw = img.shape[0:2]
        if random.random() <= prob and ih >= 20 and iw >= 20:
            out = tia_stretch(out, random.randint(3, 6))

    if config.perspective:
        if random.random() <= prob:
            out = tia_perspective(out)

    if config.crop:
        ih, iw = img.shape[0:2]
        if random.random() <= prob and ih >= 20 and iw >= 20:
            out = get_crop(out)

    if config.blur:
        if random.random() <= prob:
            out = blur(out)
    if config.color:
        if random.random() <= prob:
            out = cvtColor(out)
    if config.jitter:
        out = jitter(out)
    if config.noise:
        if random.random() <= prob:
            out = add_gasuss_noise(out)
    if config.reverse:
        if random.random() <= prob:
            out = 255 - out  # invert colors
    return out

+ 774 - 0
ocr/ppocr/data/imaug/sast_process.py

@@ -0,0 +1,774 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import math
+import cv2
+import numpy as np
+import json
+import sys
+import os
+
+__all__ = ['SASTProcessTrain']
+
+
+class SASTProcessTrain(object):
+    """
+    Label-generation op for SAST text-detection training.
+
+    Given an image and its text polygons (``data['image']``, ``data['polys']``,
+    ``data['ignore_tags']``), produces the training targets: tcl score map,
+    tbo border map, training mask, and tvo/tco geometry maps.  Returns None
+    from ``__call__`` to drop a sample that fails any validity check.
+    """
+    def __init__(self,
+                 image_shape=[512, 512],
+                 min_crop_size=24,
+                 min_crop_side_ratio=0.3,
+                 min_text_size=10,
+                 max_text_size=512,
+                 **kwargs):
+        # image_shape[1] is used as the (square) network input size.
+        self.input_size = image_shape[1]
+        self.min_crop_size = min_crop_size
+        self.min_crop_side_ratio = min_crop_side_ratio
+        self.min_text_size = min_text_size
+        self.max_text_size = max_text_size
+
+    def quad_area(self, poly):
+        """
+        compute area of a polygon
+        Signed (shoelace-style) area of a 4-point poly; the sign encodes
+        vertex orientation and is used by check_and_validate_polys to detect
+        and fix reversed polygons.
+        :param poly:
+        :return:
+        """
+        edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
+                (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
+                (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
+                (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
+        return np.sum(edge) / 2.
+
+    def gen_quad_from_poly(self, poly):
+        """
+        Generate min area quad from poly.
+        Fits cv2.minAreaRect to the poly and reorders the rect corners so
+        corner 0 best aligns with the poly's first point (and the quad walks
+        the poly in order).
+        """
+        point_num = poly.shape[0]
+        min_area_quad = np.zeros((4, 2), dtype=np.float32)
+        # NOTE(review): always-true branch kept from original; body mirrors
+        # the non-4-point case of gen_min_area_quad_from_poly below.
+        if True:
+            rect = cv2.minAreaRect(poly.astype(
+                np.int32))  # (center (x,y), (width, height), angle of rotation)
+            center_point = rect[0]
+            box = np.array(cv2.boxPoints(rect))
+
+            first_point_idx = 0
+            min_dist = 1e4
+            # choose the rect-corner rotation minimizing total distance to
+            # the poly's four "anchor" points (ends of upper/lower edges)
+            for i in range(4):
+                dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
+                    np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
+                    np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
+                    np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+                if dist < min_dist:
+                    min_dist = dist
+                    first_point_idx = i
+            for i in range(4):
+                min_area_quad[i] = box[(first_point_idx + i) % 4]
+
+        return min_area_quad
+
+    def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
+        """
+        check so that the text poly is in the same direction,
+        and also filter some invalid polygons
+        NOTE(review): ``xxx_todo_changeme`` is a lib2to3 artifact — it is the
+        (h, w) image-size tuple, unpacked on the first line below.
+        :param polys:
+        :param tags:
+        :return: (validated_polys, validated_tags, hv_tags) where hv_tag is
+            1 for horizontal-ish text and 0 when height dominates width
+        """
+        (h, w) = xxx_todo_changeme
+        if polys.shape[0] == 0:
+            return polys, np.array([]), np.array([])
+        # clamp coordinates into the image
+        polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
+        polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
+
+        validated_polys = []
+        validated_tags = []
+        hv_tags = []
+        for poly, tag in zip(polys, tags):
+            quad = self.gen_quad_from_poly(poly)
+            p_area = self.quad_area(quad)
+            if abs(p_area) < 1:
+                # near-degenerate polygon — drop it
+                print('invalid poly')
+                continue
+            if p_area > 0:
+                if tag == False:
+                    print('poly in wrong direction')
+                    tag = True  # reversed cases should be ignore
+                # reverse the 16-point poly (and the quad) to the canonical
+                # winding direction
+                poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
+                             1), :]
+                quad = quad[(0, 3, 2, 1), :]
+
+            len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
+                                                                       quad[2])
+            len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
+                                                                       quad[2])
+            hv_tag = 1
+
+            # mark clearly vertical text (height more than 2x width)
+            if len_w * 2.0 < len_h:
+                hv_tag = 0
+
+            validated_polys.append(poly)
+            validated_tags.append(tag)
+            hv_tags.append(hv_tag)
+        return np.array(validated_polys), np.array(validated_tags), np.array(
+            hv_tags)
+
+    def crop_area(self,
+                  im,
+                  polys,
+                  tags,
+                  hv_tags,
+                  crop_background=False,
+                  max_tries=25):
+        """
+        make random crop from the input image
+        Crop boundaries are sampled only from rows/columns that cross no
+        text polygon, so text is never cut in half.
+        :param im:
+        :param polys:
+        :param tags:
+        :param crop_background: if True, return a crop containing no text
+        :param max_tries: 50 -> 25
+        :return:
+        """
+        h, w, _ = im.shape
+        pad_h = h // 10
+        pad_w = w // 10
+        # 1 marks rows/cols occupied by some text polygon (in padded coords)
+        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+        for poly in polys:
+            poly = np.round(poly, decimals=0).astype(np.int32)
+            minx = np.min(poly[:, 0])
+            maxx = np.max(poly[:, 0])
+            w_array[minx + pad_w:maxx + pad_w] = 1
+            miny = np.min(poly[:, 1])
+            maxy = np.max(poly[:, 1])
+            h_array[miny + pad_h:maxy + pad_h] = 1
+        # ensure the cropped area not across a text
+        h_axis = np.where(h_array == 0)[0]
+        w_axis = np.where(w_array == 0)[0]
+        if len(h_axis) == 0 or len(w_axis) == 0:
+            return im, polys, tags, hv_tags
+        for i in range(max_tries):
+            xx = np.random.choice(w_axis, size=2)
+            xmin = np.min(xx) - pad_w
+            xmax = np.max(xx) - pad_w
+            xmin = np.clip(xmin, 0, w - 1)
+            xmax = np.clip(xmax, 0, w - 1)
+            yy = np.random.choice(h_axis, size=2)
+            ymin = np.min(yy) - pad_h
+            ymax = np.max(yy) - pad_h
+            ymin = np.clip(ymin, 0, h - 1)
+            ymax = np.clip(ymax, 0, h - 1)
+            # if xmax - xmin < ARGS.min_crop_side_ratio * w or \
+            #   ymax - ymin < ARGS.min_crop_side_ratio * h:
+            if xmax - xmin < self.min_crop_size or \
+            ymax - ymin < self.min_crop_size:
+                # area too small
+                continue
+            if polys.shape[0] != 0:
+                # polys with all 4 visible corners inside the crop
+                poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
+                                    & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
+                selected_polys = np.where(
+                    np.sum(poly_axis_in_area, axis=1) == 4)[0]
+            else:
+                selected_polys = []
+            if len(selected_polys) == 0:
+                # no text in this area
+                if crop_background:
+                    return im[ymin : ymax + 1, xmin : xmax + 1, :], \
+                        polys[selected_polys], tags[selected_polys], hv_tags[selected_polys]
+                else:
+                    continue
+            im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+            polys = polys[selected_polys]
+            tags = tags[selected_polys]
+            hv_tags = hv_tags[selected_polys]
+            # shift poly coordinates into the crop's frame
+            polys[:, :, 0] -= xmin
+            polys[:, :, 1] -= ymin
+            return im, polys, tags, hv_tags
+
+        # all tries failed — return the input unchanged
+        return im, polys, tags, hv_tags
+
+    def generate_direction_map(self, poly_quads, direction_map):
+        """
+        Fill direction_map with a per-pixel text-direction label for each
+        quad: (dx, dy) of the normalized left-to-right direction vector
+        scaled by the average quad width, plus 1/average_height.
+        """
+        width_list = []
+        height_list = []
+        for quad in poly_quads:
+            quad_w = (np.linalg.norm(quad[0] - quad[1]) +
+                      np.linalg.norm(quad[2] - quad[3])) / 2.0
+            quad_h = (np.linalg.norm(quad[0] - quad[3]) +
+                      np.linalg.norm(quad[2] - quad[1])) / 2.0
+            width_list.append(quad_w)
+            height_list.append(quad_h)
+        norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
+        average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
+
+        for quad in poly_quads:
+            # vector from the quad's left edge midpoint to its right edge midpoint
+            direct_vector_full = (
+                (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
+            direct_vector = direct_vector_full / (
+                np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
+            direction_label = tuple(
+                map(float, [
+                    direct_vector[0], direct_vector[1], 1.0 / (average_height +
+                                                               1e-6)
+                ]))
+            cv2.fillPoly(direction_map,
+                         quad.round().astype(np.int32)[np.newaxis, :, :],
+                         direction_label)
+        return direction_map
+
+    def calculate_average_height(self, poly_quads):
+        """
+        Average quad height over all quads (left/right edge mean), floored at 1.
+        """
+        height_list = []
+        for quad in poly_quads:
+            quad_h = (np.linalg.norm(quad[0] - quad[3]) +
+                      np.linalg.norm(quad[2] - quad[1])) / 2.0
+            height_list.append(quad_h)
+        average_height = max(sum(height_list) / len(height_list), 1.0)
+        return average_height
+
+    def generate_tcl_label(self,
+                           hw,
+                           polys,
+                           tags,
+                           ds_ratio,
+                           tcl_ratio=0.3,
+                           shrink_ratio_of_width=0.15):
+        """
+        Generate polygon.
+        Builds, at ds_ratio scale: the tcl score map, the 5-channel tbo
+        (border-offset) map, and the training mask (ignored text gets 0.15).
+        """
+        h, w = hw
+        h, w = int(h * ds_ratio), int(w * ds_ratio)
+        polys = polys * ds_ratio
+
+        score_map = np.zeros(
+            (
+                h,
+                w, ), dtype=np.float32)
+        tbo_map = np.zeros((h, w, 5), dtype=np.float32)
+        training_mask = np.ones(
+            (
+                h,
+                w, ), dtype=np.float32)
+        direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
+            [1, 1, 3]).astype(np.float32)
+
+        for poly_idx, poly_tag in enumerate(zip(polys, tags)):
+            poly = poly_tag[0]
+            tag = poly_tag[1]
+
+            # generate min_area_quad
+            min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
+            min_area_quad_h = 0.5 * (
+                np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
+                np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
+            min_area_quad_w = 0.5 * (
+                np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
+                np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
+
+            # skip text that is too small or too large at this scale
+            if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
+                or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
+                continue
+
+            if tag:
+                # continue
+                # ignored text: down-weight instead of train
+                cv2.fillPoly(training_mask,
+                             poly.astype(np.int32)[np.newaxis, :, :], 0.15)
+            else:
+                tcl_poly = self.poly2tcl(poly, tcl_ratio)
+                tcl_quads = self.poly2quads(tcl_poly)
+                poly_quads = self.poly2quads(poly)
+                # stcl map
+                stcl_quads, quad_index = self.shrink_poly_along_width(
+                    tcl_quads,
+                    shrink_ratio_of_width=shrink_ratio_of_width,
+                    expand_height_ratio=1.0 / tcl_ratio)
+                # generate tcl map
+                cv2.fillPoly(score_map,
+                             np.round(stcl_quads).astype(np.int32), 1.0)
+
+                # generate tbo map
+                for idx, quad in enumerate(stcl_quads):
+                    quad_mask = np.zeros((h, w), dtype=np.float32)
+                    quad_mask = cv2.fillPoly(
+                        quad_mask,
+                        np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
+                    tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
+                                                quad_mask, tbo_map)
+        return score_map, tbo_map, training_mask
+
+    def generate_tvo_and_tco(self,
+                             hw,
+                             polys,
+                             tags,
+                             tcl_ratio=0.3,
+                             ds_ratio=0.25):
+        """
+        Generate tcl map, tvo map and tbo map.
+        tvo: per-pixel offsets to the 4 quad vertices (8 ch) + short-edge
+        normalizer; tco: offsets to the quad center (2 ch) + normalizer.
+        """
+        h, w = hw
+        h, w = int(h * ds_ratio), int(w * ds_ratio)
+        polys = polys * ds_ratio
+        poly_mask = np.zeros((h, w), dtype=np.float32)
+
+        # channels 0,2,4,6 hold x coords; 1,3,5,7 hold y coords; last is 1s
+        tvo_map = np.ones((9, h, w), dtype=np.float32)
+        tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
+        tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
+        poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)
+
+        # tco map
+        tco_map = np.ones((3, h, w), dtype=np.float32)
+        tco_map[0] = np.tile(np.arange(0, w), (h, 1))
+        tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
+        poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)
+
+        poly_short_edge_map = np.ones((h, w), dtype=np.float32)
+
+        for poly, poly_tag in zip(polys, tags):
+
+            # skip ignored text
+            if poly_tag == True:
+                continue
+
+            # adjust point order for vertical poly
+            poly = self.adjust_point(poly)
+
+            # generate min_area_quad
+            min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
+            min_area_quad_h = 0.5 * (
+                np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
+                np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
+            min_area_quad_w = 0.5 * (
+                np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
+                np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
+
+            # generate tcl map and text, 128 * 128
+            tcl_poly = self.poly2tcl(poly, tcl_ratio)
+
+            # generate poly_tv_xy_map
+            for idx in range(4):
+                cv2.fillPoly(
+                    poly_tv_xy_map[2 * idx],
+                    np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
+                    float(min(max(min_area_quad[idx, 0], 0), w)))
+                cv2.fillPoly(
+                    poly_tv_xy_map[2 * idx + 1],
+                    np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
+                    float(min(max(min_area_quad[idx, 1], 0), h)))
+
+            # generate poly_tc_xy_map
+            for idx in range(2):
+                cv2.fillPoly(
+                    poly_tc_xy_map[idx],
+                    np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
+                    float(center_point[idx]))
+
+            # generate poly_short_edge_map
+            cv2.fillPoly(
+                poly_short_edge_map,
+                np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
+                float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
+
+            # generate poly_mask and training_mask
+            cv2.fillPoly(poly_mask,
+                         np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
+                         1)
+
+        # convert absolute pixel coords into offsets normalized by the
+        # short edge, then move channels last
+        tvo_map *= poly_mask
+        tvo_map[:8] -= poly_tv_xy_map
+        tvo_map[-1] /= poly_short_edge_map
+        tvo_map = tvo_map.transpose((1, 2, 0))
+
+        tco_map *= poly_mask
+        tco_map[:2] -= poly_tc_xy_map
+        tco_map[-1] /= poly_short_edge_map
+        tco_map = tco_map.transpose((1, 2, 0))
+
+        return tvo_map, tco_map
+
+    def adjust_point(self, poly):
+        """
+        adjust point order.
+        Rotates the point order so the long (reading) direction comes first:
+        for 4-point polys when the vertical edges clearly dominate, for
+        longer polys when the first corner angle exceeds 70 degrees.
+        """
+        point_num = poly.shape[0]
+        if point_num == 4:
+            len_1 = np.linalg.norm(poly[0] - poly[1])
+            len_2 = np.linalg.norm(poly[1] - poly[2])
+            len_3 = np.linalg.norm(poly[2] - poly[3])
+            len_4 = np.linalg.norm(poly[3] - poly[0])
+
+            if (len_1 + len_3) * 1.5 < (len_2 + len_4):
+                poly = poly[[1, 2, 3, 0], :]
+
+        elif point_num > 4:
+            vector_1 = poly[0] - poly[1]
+            vector_2 = poly[1] - poly[2]
+            cos_theta = np.dot(vector_1, vector_2) / (
+                np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
+            theta = np.arccos(np.round(cos_theta, decimals=4))
+
+            if abs(theta) > (70 / 180 * math.pi):
+                index = list(range(1, point_num)) + [0]
+                poly = poly[np.array(index), :]
+        return poly
+
+    def gen_min_area_quad_from_poly(self, poly):
+        """
+        Generate min area quad from poly.
+        Like gen_quad_from_poly but also returns the center point; 4-point
+        polys are passed through with their centroid.
+        """
+        point_num = poly.shape[0]
+        min_area_quad = np.zeros((4, 2), dtype=np.float32)
+        if point_num == 4:
+            min_area_quad = poly
+            center_point = np.sum(poly, axis=0) / 4
+        else:
+            rect = cv2.minAreaRect(poly.astype(
+                np.int32))  # (center (x,y), (width, height), angle of rotation)
+            center_point = rect[0]
+            box = np.array(cv2.boxPoints(rect))
+
+            first_point_idx = 0
+            min_dist = 1e4
+            for i in range(4):
+                dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
+                    np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
+                    np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
+                    np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+                if dist < min_dist:
+                    min_dist = dist
+                    first_point_idx = i
+
+            for i in range(4):
+                min_area_quad[i] = box[(first_point_idx + i) % 4]
+
+        return min_area_quad, center_point
+
+    def shrink_quad_along_width(self,
+                                quad,
+                                begin_width_ratio=0.,
+                                end_width_ratio=1.):
+        """
+        Generate shrink_quad_along_width.
+        Keeps only the [begin_width_ratio, end_width_ratio] slice of the quad
+        along its width (interpolated on the top and bottom edges).
+        """
+        ratio_pair = np.array(
+            [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+        p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+        p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+        return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+    def shrink_poly_along_width(self,
+                                quads,
+                                shrink_ratio_of_width,
+                                expand_height_ratio=1.0):
+        """
+        shrink poly with given length.
+        Trims the quad chain symmetrically from both ends by a length
+        proportional to shrink_ratio_of_width, capped by the left/right edge
+        heights; returns the trimmed quads and their source indices.
+        """
+        upper_edge_list = []
+
+        def get_cut_info(edge_len_list, cut_len):
+            # walk the upper edges until cut_len is consumed; return the
+            # quad index and the within-quad width ratio of the cut
+            for idx, edge_len in enumerate(edge_len_list):
+                cut_len -= edge_len
+                if cut_len <= 0.000001:
+                    ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
+                    return idx, ratio
+
+        for quad in quads:
+            upper_edge_len = np.linalg.norm(quad[0] - quad[1])
+            upper_edge_list.append(upper_edge_len)
+
+        # length of left edge and right edge.
+        left_length = np.linalg.norm(quads[0][0] - quads[0][
+            3]) * expand_height_ratio
+        right_length = np.linalg.norm(quads[-1][1] - quads[-1][
+            2]) * expand_height_ratio
+
+        shrink_length = min(left_length, right_length,
+                            sum(upper_edge_list)) * shrink_ratio_of_width
+        # shrinking length
+        upper_len_left = shrink_length
+        upper_len_right = sum(upper_edge_list) - shrink_length
+
+        left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
+        left_quad = self.shrink_quad_along_width(
+            quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
+        right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
+        right_quad = self.shrink_quad_along_width(
+            quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
+
+        out_quad_list = []
+        if left_idx == right_idx:
+            out_quad_list.append(
+                [left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
+        else:
+            out_quad_list.append(left_quad)
+            for idx in range(left_idx + 1, right_idx):
+                out_quad_list.append(quads[idx])
+            out_quad_list.append(right_quad)
+
+        return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
+
+    def vector_angle(self, A, B):
+        """
+        Calculate the angle between vector AB and x-axis positive direction.
+        """
+        AB = np.array([B[1] - A[1], B[0] - A[0]])
+        return np.arctan2(*AB)
+
+    def theta_line_cross_point(self, theta, point):
+        """
+        Calculate the line through given point and angle in ax + by + c =0 form.
+        """
+        x, y = point
+        cos = np.cos(theta)
+        sin = np.sin(theta)
+        return [sin, -cos, cos * y - sin * x]
+
+    def line_cross_two_point(self, A, B):
+        """
+        Calculate the line through given point A and B in ax + by + c =0 form.
+        """
+        angle = self.vector_angle(A, B)
+        return self.theta_line_cross_point(angle, A)
+
+    def average_angle(self, poly):
+        """
+        Calculate the average angle between left and right edge in given poly.
+        """
+        p0, p1, p2, p3 = poly
+        angle30 = self.vector_angle(p3, p0)
+        angle21 = self.vector_angle(p2, p1)
+        return (angle30 + angle21) / 2
+
+    def line_cross_point(self, line1, line2):
+        """
+        line1 and line2 in  0=ax+by+c form, compute the cross point of line1 and line2
+        Parallel lines (zero determinant) return the origin as a fallback.
+        """
+        a1, b1, c1 = line1
+        a2, b2, c2 = line2
+        d = a1 * b2 - a2 * b1
+
+        if d == 0:
+            #print("line1", line1)
+            #print("line2", line2)
+            print('Cross point does not exist')
+            return np.array([0, 0], dtype=np.float32)
+        else:
+            x = (b1 * c2 - b2 * c1) / d
+            y = (a2 * c1 - a1 * c2) / d
+
+        return np.array([x, y], dtype=np.float32)
+
+    def quad2tcl(self, poly, ratio):
+        """
+        Generate center line by poly clock-wise point. (4, 2)
+        """
+        ratio_pair = np.array(
+            [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+        p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
+        p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
+        return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
+
+    def poly2tcl(self, poly, ratio):
+        """
+        Generate center line by poly clock-wise point.
+        Shrinks each top/bottom point pair toward the vertical midline by
+        ``ratio`` of the pair's height.
+        """
+        ratio_pair = np.array(
+            [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+        tcl_poly = np.zeros_like(poly)
+        point_num = poly.shape[0]
+
+        for idx in range(point_num // 2):
+            point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
+                                      ) * ratio_pair
+            tcl_poly[idx] = point_pair[0]
+            tcl_poly[point_num - 1 - idx] = point_pair[1]
+        return tcl_poly
+
+    def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
+        """
+        Generate tbo_map for give quad.
+        For each tcl pixel, stores the (y, x) offsets to the upper and lower
+        quad borders along the quad's average side direction, plus a
+        short-edge normalization channel.
+        """
+        # upper and lower line function: ax + by + c = 0;
+        up_line = self.line_cross_two_point(quad[0], quad[1])
+        lower_line = self.line_cross_two_point(quad[3], quad[2])
+
+        quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
+                        np.linalg.norm(quad[1] - quad[2]))
+        quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
+                        np.linalg.norm(quad[2] - quad[3]))
+
+        # average angle of left and right line.
+        angle = self.average_angle(quad)
+
+        xy_in_poly = np.argwhere(tcl_mask == 1)
+        for y, x in xy_in_poly:
+            point = (x, y)
+            line = self.theta_line_cross_point(angle, point)
+            cross_point_upper = self.line_cross_point(up_line, line)
+            cross_point_lower = self.line_cross_point(lower_line, line)
+            ##FIX, offset reverse
+            upper_offset_x, upper_offset_y = cross_point_upper - point
+            lower_offset_x, lower_offset_y = cross_point_lower - point
+            tbo_map[y, x, 0] = upper_offset_y
+            tbo_map[y, x, 1] = upper_offset_x
+            tbo_map[y, x, 2] = lower_offset_y
+            tbo_map[y, x, 3] = lower_offset_x
+            tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
+        return tbo_map
+
+    def poly2quads(self, poly):
+        """
+        Split poly into quads.
+        A 2N-point poly becomes N-1 clock-wise quads from consecutive
+        top/bottom point pairs.
+        """
+        quad_list = []
+        point_num = poly.shape[0]
+
+        # point pair
+        point_pair_list = []
+        for idx in range(point_num // 2):
+            point_pair = [poly[idx], poly[point_num - 1 - idx]]
+            point_pair_list.append(point_pair)
+
+        quad_num = point_num // 2 - 1
+        for idx in range(quad_num):
+            # reshape and adjust to clock-wise
+            quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
+                              ).reshape(4, 2)[[0, 2, 3, 1]])
+
+        return np.array(quad_list)
+
+    def __call__(self, data):
+        """
+        Build all SAST training targets for one sample dict.
+        Returns the augmented ``data`` dict, or None to drop the sample
+        whenever a validity/size check fails.
+        """
+        im = data['image']
+        text_polys = data['polys']
+        text_tags = data['ignore_tags']
+        if im is None:
+            return None
+        if text_polys.shape[0] == 0:
+            return None
+
+        h, w, _ = im.shape
+        text_polys, text_tags, hv_tags = self.check_and_validate_polys(
+            text_polys, text_tags, (h, w))
+
+        if text_polys.shape[0] == 0:
+            return None
+
+        #set aspect ratio and keep area fix
+        asp_scales = np.arange(1.0, 1.55, 0.1)
+        asp_scale = np.random.choice(asp_scales)
+
+        if np.random.rand() < 0.5:
+            asp_scale = 1.0 / asp_scale
+        asp_scale = math.sqrt(asp_scale)
+
+        asp_wx = asp_scale
+        asp_hy = 1.0 / asp_scale
+        im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
+        text_polys[:, :, 0] *= asp_wx
+        text_polys[:, :, 1] *= asp_hy
+
+        # cap the longer side at 2048 px
+        h, w, _ = im.shape
+        if max(h, w) > 2048:
+            rd_scale = 2048.0 / max(h, w)
+            im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
+            text_polys *= rd_scale
+        h, w, _ = im.shape
+        if min(h, w) < 16:
+            return None
+
+        #no background
+        im, text_polys, text_tags, hv_tags = self.crop_area(im, \
+            text_polys, text_tags, hv_tags, crop_background=False)
+
+        if text_polys.shape[0] == 0:
+            return None
+        #continue for all ignore case
+        if np.sum((text_tags * 1.0)) >= text_tags.size:
+            return None
+        new_h, new_w, _ = im.shape
+        if (new_h is None) or (new_w is None):
+            return None
+        #resize image
+        std_ratio = float(self.input_size) / max(new_w, new_h)
+        rand_scales = np.array(
+            [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
+        rz_scale = std_ratio * np.random.choice(rand_scales)
+        im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
+        text_polys[:, :, 0] *= rz_scale
+        text_polys[:, :, 1] *= rz_scale
+
+        #add gaussian blur
+        if np.random.rand() < 0.1 * 0.5:
+            ks = np.random.permutation(5)[0] + 1
+            ks = int(ks / 2) * 2 + 1
+            im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
+        #add brighter
+        if np.random.rand() < 0.1 * 0.5:
+            im = im * (1.0 + np.random.rand() * 0.5)
+            im = np.clip(im, 0.0, 255.0)
+        #add darker
+        if np.random.rand() < 0.1 * 0.5:
+            im = im * (1.0 - np.random.rand() * 0.5)
+            im = np.clip(im, 0.0, 255.0)
+
+        # Padding the im to [input_size, input_size]
+        new_h, new_w, _ = im.shape
+        if min(new_w, new_h) < self.input_size * 0.5:
+            return None
+
+        # pad with ImageNet mean (BGR channel order: index 2 = R mean, etc.)
+        im_padded = np.ones(
+            (self.input_size, self.input_size, 3), dtype=np.float32)
+        im_padded[:, :, 2] = 0.485 * 255
+        im_padded[:, :, 1] = 0.456 * 255
+        im_padded[:, :, 0] = 0.406 * 255
+
+        # Random the start position
+        del_h = self.input_size - new_h
+        del_w = self.input_size - new_w
+        sh, sw = 0, 0
+        if del_h > 1:
+            sh = int(np.random.rand() * del_h)
+        if del_w > 1:
+            sw = int(np.random.rand() * del_w)
+
+        # Padding
+        im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
+        text_polys[:, :, 0] += sw
+        text_polys[:, :, 1] += sh
+
+        score_map, border_map, training_mask = self.generate_tcl_label(
+            (self.input_size, self.input_size), text_polys, text_tags, 0.25)
+
+        # SAST head
+        tvo_map, tco_map = self.generate_tvo_and_tco(
+            (self.input_size, self.input_size),
+            text_polys,
+            text_tags,
+            tcl_ratio=0.3,
+            ds_ratio=0.25)
+        # print("test--------tvo_map shape:", tvo_map.shape)
+
+        # ImageNet mean/std normalization, then HWC -> CHW
+        im_padded[:, :, 2] -= 0.485 * 255
+        im_padded[:, :, 1] -= 0.456 * 255
+        im_padded[:, :, 0] -= 0.406 * 255
+        im_padded[:, :, 2] /= (255.0 * 0.229)
+        im_padded[:, :, 1] /= (255.0 * 0.224)
+        im_padded[:, :, 0] /= (255.0 * 0.225)
+        im_padded = im_padded.transpose((2, 0, 1))
+
+        # channel reversal BGR -> RGB
+        data['image'] = im_padded[::-1, :, :]
+        data['score_map'] = score_map[np.newaxis, :, :]
+        data['border_map'] = border_map.transpose((2, 0, 1))
+        data['training_mask'] = training_mask[np.newaxis, :, :]
+        data['tvo_map'] = tvo_map.transpose((2, 0, 1))
+        data['tco_map'] = tco_map.transpose((2, 0, 1))
+        return data

+ 17 - 0
ocr/ppocr/data/imaug/text_image_aug/__init__.py

@@ -0,0 +1,17 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .augment import tia_perspective, tia_distort, tia_stretch
+
+__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']

+ 116 - 0
ocr/ppocr/data/imaug/text_image_aug/augment.py

@@ -0,0 +1,116 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from .warp_mls import WarpMLS
+
+
def tia_distort(src, segment=4):
    """Randomly distort ``src`` by jittering boundary control points (TIA).

    The four corners plus ``segment - 1`` evenly spaced points on the top
    and bottom edges are each shifted by up to ``thresh`` pixels, then the
    image is warped with a moving-least-squares transform.

    NOTE(review): if img_w < 3 * segment then thresh == 0 and
    np.random.randint(0) raises -- same as the original behavior.
    """
    img_h, img_w = src.shape[:2]
    cut = img_w // segment
    thresh = cut // 3
    half_thresh = thresh * 0.5

    # Fixed source control points: the four image corners.
    src_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]

    # Destination corners, each jittered by < thresh pixels (RNG call order
    # matches the original: TL, TR, BR, BL).
    dst_pts = [
        [np.random.randint(thresh), np.random.randint(thresh)],
        [img_w - np.random.randint(thresh), np.random.randint(thresh)],
        [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)],
        [np.random.randint(thresh), img_h - np.random.randint(thresh)],
    ]

    # Interior edge subdivisions, jittered around their source positions.
    for seg_idx in range(1, segment):
        src_pts.append([cut * seg_idx, 0])
        src_pts.append([cut * seg_idx, img_h])
        dst_pts.append([
            cut * seg_idx + np.random.randint(thresh) - half_thresh,
            np.random.randint(thresh) - half_thresh
        ])
        dst_pts.append([
            cut * seg_idx + np.random.randint(thresh) - half_thresh,
            img_h + np.random.randint(thresh) - half_thresh
        ])

    warper = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    return warper.generate()
+
+
def tia_stretch(src, segment=4):
    """Randomly stretch ``src`` horizontally (TIA augmentation).

    The corners stay fixed; interior vertical dividers are shifted left or
    right by up to ``thresh`` pixels and the image is MLS-warped to match.
    """
    img_h, img_w = src.shape[:2]
    cut = img_w // segment
    thresh = cut * 4 // 5
    half_thresh = thresh * 0.5

    # Corners map onto themselves: only interior columns move.
    src_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]
    dst_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]

    for seg_idx in range(1, segment):
        # One shared horizontal shift per divider (top and bottom move together).
        shift = np.random.randint(thresh) - half_thresh
        src_pts.append([cut * seg_idx, 0])
        src_pts.append([cut * seg_idx, img_h])
        dst_pts.append([cut * seg_idx + shift, 0])
        dst_pts.append([cut * seg_idx + shift, img_h])

    warper = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    return warper.generate()
+
+
def tia_perspective(src):
    """Apply a random vertical pseudo-perspective squeeze (TIA augmentation).

    The left and right edges are independently shortened from the top and
    bottom by up to half the image height, then MLS-warped.
    """
    img_h, img_w = src.shape[:2]
    thresh = img_h // 2

    src_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]

    # RNG call order matches the original: top-left, top-right,
    # bottom-right, bottom-left.
    top_left = np.random.randint(thresh)
    top_right = np.random.randint(thresh)
    bottom_right = img_h - np.random.randint(thresh)
    bottom_left = img_h - np.random.randint(thresh)
    dst_pts = [[0, top_left], [img_w, top_right],
               [img_w, bottom_right], [0, bottom_left]]

    warper = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    return warper.generate()

+ 164 - 0
ocr/ppocr/data/imaug/text_image_aug/warp_mls.py

@@ -0,0 +1,164 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
class WarpMLS:
    """Image deformation by Moving Least Squares (similarity variant).

    Displacements are computed only at the corners of a coarse grid
    (``grid_size`` pixels) in ``calc_delta`` and bilinearly interpolated
    inside each cell when rendering in ``gen_img``.
    """

    def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
        # src: source image array, H x W or H x W x C -- assumed uint8; TODO confirm
        self.src = src
        self.src_pts = src_pts        # control points in the source image
        self.dst_pts = dst_pts        # matching control points in the output
        self.pt_count = len(self.dst_pts)
        self.dst_w = dst_w
        self.dst_h = dst_h
        # Scales the interpolated displacement (1.0 = full deformation).
        self.trans_ratio = trans_ratio
        self.grid_size = 100
        # Per-pixel x / y displacement fields; only grid-corner entries are filled.
        self.rdx = np.zeros((self.dst_h, self.dst_w))
        self.rdy = np.zeros((self.dst_h, self.dst_w))

    @staticmethod
    def __bilinear_interp(x, y, v11, v12, v21, v22):
        # Standard bilinear blend; x and y are fractional offsets in [0, 1].
        return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
                                                      (1 - y) + v22 * y) * x

    def generate(self):
        """Compute grid displacements, then render and return the warped image."""
        self.calc_delta()
        return self.gen_img()

    def calc_delta(self):
        """Fill ``rdx``/``rdy`` at grid-corner pixels with MLS displacements."""
        w = np.zeros(self.pt_count, dtype=np.float32)

        # MLS needs at least two control-point pairs.
        if self.pt_count < 2:
            return

        i = 0
        while 1:
            # Clamp the final column to dst_w - 1 so the right border is
            # always evaluated; the loop then exits on the following pass.
            if self.dst_w <= i < self.dst_w + self.grid_size - 1:
                i = self.dst_w - 1
            elif i >= self.dst_w:
                break

            j = 0
            while 1:
                # Same clamping trick for the bottom border.
                if self.dst_h <= j < self.dst_h + self.grid_size - 1:
                    j = self.dst_h - 1
                elif j >= self.dst_h:
                    break

                sw = 0
                swp = np.zeros(2, dtype=np.float32)
                swq = np.zeros(2, dtype=np.float32)
                new_pt = np.zeros(2, dtype=np.float32)
                cur_pt = np.array([i, j], dtype=np.float32)

                k = 0
                for k in range(self.pt_count):
                    # Grid point coincides with a control point: break and
                    # use its source twin directly (handled below).
                    if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
                        break

                    # Inverse-square-distance weight of control point k.
                    w[k] = 1. / (
                        (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
                        (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))

                    sw += w[k]
                    swp = swp + w[k] * np.array(self.dst_pts[k])
                    swq = swq + w[k] * np.array(self.src_pts[k])

                # The loop ran to completion iff k == pt_count - 1.
                # NOTE(review): a coincidence at the *last* control point also
                # leaves k == pt_count - 1 and is misclassified as "no
                # coincidence" -- upstream behavior, kept as-is.
                if k == self.pt_count - 1:
                    # Weighted centroids of destination / source control points.
                    pstar = 1 / sw * swp
                    qstar = 1 / sw * swq

                    # Normalisation term of the similarity-MLS transform.
                    miu_s = 0
                    for k in range(self.pt_count):
                        if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
                            continue
                        pt_i = self.dst_pts[k] - pstar
                        miu_s += w[k] * np.sum(pt_i * pt_i)

                    cur_pt -= pstar
                    # 90-degree rotation of cur_pt.
                    cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])

                    for k in range(self.pt_count):
                        if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
                            continue

                        pt_i = self.dst_pts[k] - pstar
                        pt_j = np.array([-pt_i[1], pt_i[0]])

                        # Similarity-MLS contribution of control point k.
                        tmp_pt = np.zeros(2, dtype=np.float32)
                        tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
                                    np.sum(pt_j * cur_pt) * self.src_pts[k][1]
                        tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
                                    np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
                        tmp_pt *= (w[k] / miu_s)
                        new_pt += tmp_pt

                    new_pt += qstar
                else:
                    # Exact control point: map straight to its source twin.
                    new_pt = self.src_pts[k]

                # Store the displacement at this grid corner.
                self.rdx[j, i] = new_pt[0] - i
                self.rdy[j, i] = new_pt[1] - j

                j += self.grid_size
            i += self.grid_size

    def gen_img(self):
        """Render the warped image, interpolating displacements per grid cell."""
        src_h, src_w = self.src.shape[:2]
        dst = np.zeros_like(self.src, dtype=np.float32)

        for i in np.arange(0, self.dst_h, self.grid_size):
            for j in np.arange(0, self.dst_w, self.grid_size):
                ni = i + self.grid_size
                nj = j + self.grid_size
                w = h = self.grid_size
                # Shrink the last cell so it ends exactly on the image border.
                if ni >= self.dst_h:
                    ni = self.dst_h - 1
                    h = ni - i + 1
                if nj >= self.dst_w:
                    nj = self.dst_w - 1
                    w = nj - j + 1

                # Fractional positions of every pixel inside this cell.
                di = np.reshape(np.arange(h), (-1, 1))
                dj = np.reshape(np.arange(w), (1, -1))
                delta_x = self.__bilinear_interp(
                    di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
                    self.rdx[ni, j], self.rdx[ni, nj])
                delta_y = self.__bilinear_interp(
                    di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
                    self.rdy[ni, j], self.rdy[ni, nj])
                # Absolute sampling coordinates in the source image.
                nx = j + dj + delta_x * self.trans_ratio
                ny = i + di + delta_y * self.trans_ratio
                nx = np.clip(nx, 0, src_w - 1)
                ny = np.clip(ny, 0, src_h - 1)
                nxi = np.array(np.floor(nx), dtype=np.int32)
                nyi = np.array(np.floor(ny), dtype=np.int32)
                nxi1 = np.array(np.ceil(nx), dtype=np.int32)
                nyi1 = np.array(np.ceil(ny), dtype=np.int32)

                # Broadcast the fractional weights over the 3 color channels.
                if len(self.src.shape) == 3:
                    x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
                    y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
                else:
                    x = ny - nyi
                    y = nx - nxi
                # Bilinear sample of the four neighbouring source pixels.
                dst[i:i + h, j:j + w] = self.__bilinear_interp(
                    x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
                    self.src[nyi1, nxi], self.src[nyi1, nxi1])

        dst = np.clip(dst, 0, 255)
        dst = np.array(dst, dtype=np.uint8)

        return dst

+ 115 - 0
ocr/ppocr/data/lmdb_dataset.py

@@ -0,0 +1,115 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+from paddle.io import Dataset
+import lmdb
+import cv2
+
+from .imaug import transform, create_operators
+
+
class LMDBDataSet(Dataset):
    """Recognition dataset backed by one or more LMDB databases.

    Every leaf directory under ``data_dir`` is opened as an LMDB
    environment; samples from all environments are flattened into one
    global index (optionally shuffled) and decoded on access.
    """

    def __init__(self, config, mode, logger, seed=None):
        super(LMDBDataSet, self).__init__()

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        # NOTE(review): batch_size is read but never used in this class.
        batch_size = loader_config['batch_size_per_card']
        data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
        logger.info("Initialize indexs of datasets:%s" % data_dir)
        # (N, 2) array: column 0 = lmdb set index, column 1 = sample key index.
        self.data_idx_order_list = self.dataset_traversal()
        if self.do_shuffle:
            np.random.shuffle(self.data_idx_order_list)
        self.ops = create_operators(dataset_config['transforms'], global_config)

    def load_hierarchical_lmdb_dataset(self, data_dir):
        """Open every leaf directory below ``data_dir`` as a read-only LMDB env.

        Returns a dict mapping a running index to
        ``{"dirpath", "env", "txn", "num_samples"}``.
        """
        lmdb_sets = {}
        dataset_idx = 0
        for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
            # A directory without subdirectories is assumed to hold an LMDB.
            if not dirnames:
                env = lmdb.open(
                    dirpath,
                    max_readers=32,
                    readonly=True,
                    lock=False,
                    readahead=False,
                    meminit=False)
                # Long-lived read transaction reused for every lookup.
                txn = env.begin(write=False)
                num_samples = int(txn.get('num-samples'.encode()))
                lmdb_sets[dataset_idx] = {"dirpath": dirpath, "env": env,
                                          "txn": txn, "num_samples": num_samples}
                dataset_idx += 1
        return lmdb_sets

    def dataset_traversal(self):
        """Build the flat (lmdb_idx, sample_idx) index over all LMDB sets."""
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = 0
        for lno in range(lmdb_num):
            total_sample_num += self.lmdb_sets[lno]['num_samples']
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            data_idx_order_list[beg_idx:end_idx, 1] \
                = list(range(tmp_sample_num))
            # LMDB sample keys are 1-based ("label-000000001", ...).
            data_idx_order_list[beg_idx:end_idx, 1] += 1
            beg_idx = beg_idx + tmp_sample_num
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode raw image bytes to a color array; return None on any failure."""
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        if imgdata is None:
            return None
        imgori = cv2.imdecode(imgdata, 1)
        if imgori is None:
            return None
        return imgori

    def get_lmdb_sample_info(self, txn, index):
        """Fetch (image bytes, label str) for 1-based ``index``; None if absent."""
        label_key = 'label-%09d'.encode() % index
        label = txn.get(label_key)
        if label is None:
            return None
        label = label.decode('utf-8')
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)
        return imgbuf, label

    def __getitem__(self, idx):
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
                                                file_idx)
        if sample_info is None:
            # Missing record: retry with a random other sample.
            return self.__getitem__(np.random.randint(self.__len__()))
        img, label = sample_info
        data = {'image': img, 'label': label}
        outs = transform(data, self.ops)
        if outs is None:
            # Transform pipeline rejected the sample: retry with another.
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        return self.data_idx_order_list.shape[0]

+ 126 - 0
ocr/ppocr/data/simple_dataset.py

@@ -0,0 +1,126 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+import random
+from paddle.io import Dataset
+from ocr.ppocr.data.text2Image import create_image, delete_image
+from .imaug import transform, create_operators
+import sys
+sys.setrecursionlimit(100000)
+
+
class SimpleDataSet(Dataset):
    """Text-line dataset driven by label files (``image_path<delim>label`` lines).

    Lines whose image path does not start with "image" are rendered on the
    fly from the label text (``create_image``) and deleted again once the
    sample has been read and transformed.
    """

    def __init__(self, config, mode, logger, seed=None):
        super(SimpleDataSet, self).__init__()
        self.logger = logger

        global_config = config['Global']
        # Train/Eval section of the config.
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        self.delimiter = dataset_config.get('delimiter', '\t')

        # Label files mapping image paths to transcriptions.
        label_file_list = dataset_config.pop('label_file_list')
        data_source_num = len(label_file_list)
        ratio_list = dataset_config.get("ratio_list", [1.0])
        if isinstance(ratio_list, (float, int)):
            ratio_list = [float(ratio_list)] * int(data_source_num)

        assert len(
            ratio_list
        ) == data_source_num, "The length of ratio_list should be the same as the file_list."

        # Root directory of the images.
        self.data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        self.seed = seed
        logger.info("Initialize indexs of datasets:%s" % label_file_list)
        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        if mode.lower() == "train":
            self.shuffle_data_random()
        self.ops = create_operators(dataset_config['transforms'], global_config)

    def get_image_info_list(self, file_list, ratio_list):
        """Read each label file and sample ``ratio_list[i]`` of its lines."""
        if isinstance(file_list, str):
            file_list = [file_list]
        data_lines = []
        for idx, file in enumerate(file_list):
            with open(file, "rb") as f:
                lines = f.readlines()
                random.seed(self.seed)
                lines = random.sample(lines,
                                      round(len(lines) * ratio_list[idx]))
                data_lines.extend(lines)
        return data_lines

    def shuffle_data_random(self):
        """Shuffle the raw label lines in place (when shuffling is enabled)."""
        if self.do_shuffle:
            random.seed(self.seed)
            random.shuffle(self.data_lines)
        return

    def _load_sample(self, file_name, label):
        """Read one image file and run the transform pipeline on it."""
        img_path = os.path.join(self.data_dir, file_name)
        if not os.path.exists(img_path):
            raise Exception("{} does not exist!".format(img_path))
        data = {'img_path': img_path, 'label': label}
        with open(img_path, 'rb') as f:
            data['image'] = f.read()
        return transform(data, self.ops)

    def __getitem__(self, idx):
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode('utf-8')
            substr = data_line.strip("\n").split(self.delimiter)

            # Image file path and its transcription.
            file_name = substr[0]
            label = substr[1]

            # Paths not starting with "image" are rendered temporarily from
            # the label text and removed again afterwards.
            # NOTE(review): create_image/delete_image concatenate
            # ``data_dir + file_name`` while the read uses os.path.join --
            # data_dir is expected to end with a path separator; confirm.
            is_temp = file_name[:5] != "image"
            if is_temp:
                create_image(self.data_dir, file_name, label)
            try:
                outs = self._load_sample(file_name, label)
            finally:
                # Delete the temporary render even when reading or
                # transforming raises (the original leaked it on error).
                if is_temp:
                    delete_image(self.data_dir, file_name)
        except Exception as e:
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(
                    data_line, e))
            outs = None
        if outs is None:
            # Fall back to a random other sample on any failure.
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        return len(self.data_idx_order_list)

+ 455 - 0
ocr/ppocr/data/text2Image.py

@@ -0,0 +1,455 @@
+import random
+import re
+import numpy as np
+import cv2
+# import psycopg2
+from PIL import Image, ImageFont, ImageDraw
+import os
+from PIL import Image, ImageFont, ImageDraw
+import pandas as pd
+
+
+# project_path = "D:\\Project\\PaddleOCR-release-2.0\\"
+from bs4 import BeautifulSoup
+
+project_path = "../../"
+image_output_path = project_path + "train_data/bidi_data/mix_data3/"
+train_data_path = image_output_path + "rec_gt_train.txt"
+test_data_path = image_output_path + "rec_gt_test.txt"
+
+
def create_image(data_dir, file_name, text):
    """Render ``text`` onto a white canvas and save it at ``data_dir + file_name``.

    The canvas width is estimated from the counts of ASCII alphanumerics,
    CJK characters and punctuation at a randomly chosen font size; the
    render is then randomly up-scaled and Gaussian-blurred for augmentation.
    """
    list1 = re.findall('[a-zA-Z\d]', text)
    list2 = re.findall('[\u4e00-\u9fa5。,!?¥《》【】’“:;·、()]', text)
    list3 = re.findall('[,.!?&@*+=~%()#<>-''|/:{}$;]', text)
    english_len = len(list1)
    chinese_len = len(list2)
    character_len = len(list3)

    # Fallback: treat every char as punctuation-width so the canvas is
    # never zero pixels wide.
    if english_len + chinese_len + character_len == 0:
        character_len = len(text)

    # Randomly pick a font size (pixels).
    # font 10 : a1-6 char-10 image-len*, 16
    # font 20 : a1-12 char-20 image-len*, 32
    font_list = [10, 15, 20, 25, 30, 35, 40]
    font_index = random.randint(0, len(font_list) - 1)
    font_size = font_list[font_index]

    # Estimated advance width per character class at this font size.
    chinese_charc_len = font_size * 1
    english_charc_len = int(font_size * 0.7)
    number_charc_len = int(font_size * 0.3)
    image_width = int(font_size * 1.6)  # canvas height, despite the name
    text_len = english_len * english_charc_len + chinese_len * chinese_charc_len \
               + character_len * number_charc_len
    im = Image.new("RGB", (text_len, image_width), (255, 255, 255))
    dr = ImageDraw.Draw(im)
    # BUGFIX: the original reused the name ``font`` for both the size and
    # the font object; renamed to ``font_size`` above.
    font = ImageFont.truetype("tools/fonts/msyh.ttc", font_size)
    dr.text((0, 0), text, font=font, fill="#000000")

    # ---- augmentation ----
    # PIL -> OpenCV
    img = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)

    # Random integer up-scaling.
    resize_y = random.randint(1, 2)
    resize_x = random.randint(1, 2)
    img = cv2.resize(img, (img.shape[1] * resize_y, img.shape[0] * resize_x))

    # Gaussian blur.
    # BUGFIX: sigmaY must be passed by keyword -- the fourth *positional*
    # argument of cv2.GaussianBlur is ``dst``, not ``sigmaY``.
    sigmaX = random.randint(1, 3)
    sigmaY = random.randint(1, 3)
    img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY=sigmaY)

    # OpenCV -> PIL and save.
    im = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    im.save(data_dir + file_name)
+
+
def create_orgs_image(df):
    """Render the first 1000 organisation names of ``df`` to images.

    One image per row plus a ``<mode>/<file>\\t<label>`` line in the
    matching label file; the first 80% of rows go to train, the rest to test.
    """
    df = df[:1000]
    label_file_train = project_path + "train_data\\bidi_data\\orgs_data\\rec_gt_train.txt"
    label_file_test = project_path + "train_data\\bidi_data\\orgs_data\\rec_gt_test.txt"
    image_output_path = project_path + "train_data\\bidi_data\\orgs_data\\"
    print(df.shape)
    # BUGFIX: context managers guarantee the label files are closed (and
    # flushed) even if rendering a row raises; the originals leaked on error.
    with open(label_file_train, "w") as f1, open(label_file_test, "w") as f2:
        for index, row in df.iterrows():
            text = row["name"]
            im = Image.new("RGB", (len(text) * 10, 16), (255, 255, 255))
            dr = ImageDraw.Draw(im)
            font = ImageFont.truetype(os.path.join(os.getcwd(), "fonts", "msyh.ttc"), 10)
            dr.text((0, 0), text, font=font, fill="#000000")
            # 80/20 split on row position.
            if index / df.shape[0] <= 0.8:
                mode = "train"
                f = f1
            else:
                mode = "test"
                f = f2
            im.save(image_output_path + mode + "\\" + "text_" + str(index) + ".jpg")
            f.write(mode + "/text_" + str(index) + ".jpg" + "\t" + text + "\n")
+
+
def create_longSentence_image(df):
    """Render every sentence in ``df['text']`` to an image.

    One image per row plus a ``<mode>/<file>\\t<label>`` line in the
    matching label file; the first 80% of rows go to train, the rest to test.
    """
    label_file_train = project_path + "train_data\\bidi_data\\longSentence_data\\rec_gt_train.txt"
    label_file_test = project_path + "train_data\\bidi_data\\longSentence_data\\rec_gt_test.txt"
    image_output_path = project_path + "train_data\\bidi_data\\longSentence_data\\"
    print(df.shape)

    # BUGFIX: context managers guarantee the label files are closed (and
    # flushed) even if rendering a row raises; the originals leaked on error.
    with open(label_file_train, "w") as f1, open(label_file_test, "w") as f2:
        for index, row in df.iterrows():
            text = row["text"]
            im = Image.new("RGB", (len(text) * 10, 16), (255, 255, 255))
            dr = ImageDraw.Draw(im)
            font = ImageFont.truetype(os.path.join(os.getcwd(), "fonts", "msyh.ttc"), 10)
            dr.text((0, 0), text, font=font, fill="#000000")
            # 80/20 split on row position.
            if index <= int((df.shape[0] - 1) * 0.8):
                mode = "train"
                f = f1
            else:
                mode = "test"
                f = f2
            im.save(image_output_path + mode + "\\" + "text_" + str(index) + ".jpg")
            f.write(mode + "/text_" + str(index) + ".jpg" + "\t" + text + "\n")
+
+
+# def readPostgreSQL():
+#     conn_string = "host=192.168.2.101 port=5432 dbname=iepy " \
+#                   "user=iepy_read password=iepy_read"
+#     conn = psycopg2.connect(conn_string)
+#
+#     # 执行SQL语句
+#     sql = "select text from corpus_iedocument " \
+#           "where jump_signal=0"
+#     df = pd.read_sql(sql, conn)
+#     return df
+
+
+# 生成多个场景混合数据
+def create_mix_txt():
+    # 最长字符串长度
+    max_length = 100
+    # list1 = create_text_list(max_length)
+
+    list1 = create_number_list(3000000)
+    print("finish get list1", len(list1))
+    # list2 = create_org_list()
+
+    list2 = get_long_sentence_from_file(2000000)
+    print("finish get list2", len(list2))
+    # list2 = list2[0:100]
+
+    with open("appendix_text.txt", "r") as f:
+        list3 = f.readlines()
+    # list3 = list3[:6]
+    print("finish get list3", len(list3))
+
+    list4 = create_org_list()
+    # list4 = list4[:6]
+    print("finish get list4", len(list4))
+
+    train_data = list1[0:int(len(list1)*0.95)] + list2[0:int(len(list2)*0.95)] + \
+                 list3[0:int(len(list3)*0.95)] + list4[0:int(len(list4)*0.95)]
+    test_data = list1[int(len(list1)*0.95):] + list2[int(len(list2)*0.95):] + \
+                list3[int(len(list3)*0.95):] + list4[int(len(list4)*0.95):]
+    print("len(train_data)", len(train_data))
+    print("len(test_data)", len(test_data))
+
+    data_index = 0
+    with open(train_data_path, "w") as f:
+        for data in train_data:
+            prefix = "train/text_" + str(data_index) + ".jpg" + "\t"
+            data = prefix + data
+            f.write(data)
+            data_index += 1
+    print("finish write train data")
+    with open(test_data_path, "w") as f:
+        for data in test_data:
+            prefix = "test/text_" + str(data_index) + ".jpg" + "\t"
+            data = prefix + data
+            f.write(data)
+            data_index += 1
+    print("finish write test data")
+    return
+
+
+# def create_text_list(max_length):
+#     # 招投标文章语句
+#     df1 = readPostgreSQL()
+#     list1 = []
+#     for index, row in df1.iterrows():
+#         text = row["text"].split(",")
+#         # print(len(text))
+#
+#         # 每篇文章最多取10个句子
+#         max_sentence = 15
+#         sentence_count = 0
+#         while sentence_count < max_sentence:
+#             if len(text) <= max_sentence:
+#                 if sentence_count < len(text):
+#                     sentence = text[sentence_count]
+#                 else:
+#                     break
+#             else:
+#                 r1 = random.randint(0, len(text) - 1)
+#                 sentence = text[r1]
+#             if len(sentence) > max_length:
+#                 # 限制字数,随机截取前或后
+#                 r2 = random.randint(0, 1)
+#                 if r2:
+#                     sentence = sentence[:max_length]
+#                 else:
+#                     sentence = sentence[-max_length:]
+#
+#             # sentence = re.sub("\n", "", sentence)
+#             if sentence != "":
+#                 list1.append(sentence+"\n")
+#             sentence_count += 1
+#     print("len(list1)", len(list1))
+#     return list1
+
+
def delete_image(data_dir, file_name):
    """Remove ``data_dir + file_name`` if it exists; silently do nothing otherwise."""
    target = data_dir + file_name
    if os.path.exists(target):
        os.remove(target)
+
+
def create_org_list():
    """Load the ~10M company-name corpus, one newline-terminated name per line."""
    with open("C:\\Users\\Administrator\\Desktop\\LEGAL_ENTERPRISE.txt", "r") as f:
        names = f.readlines()
    return names
+
+
+def create_number_list(number):
+    no_list = []
+    for i in range(number):
+        # 随机选择生成几位小数
+        decimal_place = random.choices([0, 1, 2, 3, 4, 5, 6])[0]
+
+        if decimal_place == 0:
+            no = random.randint(0, 10000000)
+        else:
+            no = random.uniform(0, 10000)
+            no = round(no, decimal_place)
+        no_list.append(str(no)+"\n")
+    # print(no_list)
+    return no_list
+
+
def get_mix_data_from_file(number):
    """Return the label column of the first ``number`` mixed-corpus train lines."""
    with open("../../train_data/bidi_data/mix_data/rec_gt_train.txt") as f:
        lines = f.readlines()
    # Each line is "<image path>\t<label>"; keep only the label part.
    return [line.split("\t")[1] for line in lines[:number]]
+
+
def get_long_sentence_from_file(number):
    """Return the label column of the first ``number`` long-sentence lines.

    Train and test label files are concatenated (train first) before
    truncation.
    """
    with open("../../train_data/bidi_data/longSentence_data/rec_gt_train.txt") as f:
        train_lines = f.readlines()
    with open("../../train_data/bidi_data/longSentence_data/rec_gt_test.txt") as f:
        test_lines = f.readlines()

    combined = (train_lines + test_lines)[:number]
    # Each line is "<image path>\t<label>"; keep only the label part.
    return [line.split("\t")[1] for line in combined]
+
+
def get_data_from_appendix():
    """Extract attachment text from article HTML in ``dochtmlcon.xlsx``.

    For each row, the first ``div.richTextFetch`` element is located, its
    inner text is split into lines, whitespace is stripped, and one cleaned
    line per input line is written to ``appendix_text.txt``.
    """
    df = pd.read_excel("dochtmlcon.xlsx")
    collected = []
    for _, row in df.iterrows():
        soup = BeautifulSoup(row["dochtmlcon"], "html.parser",
                             from_encoding="utf-8")
        # First richTextFetch div; the slice strips a fixed-length opening
        # tag and the trailing "</div>" -- assumes the tag header length is
        # constant; TODO confirm.
        nodes = soup.find_all('div', class_='richTextFetch')
        inner = str(nodes[0])[49:-6]
        for raw_line in inner.split("\n"):
            cleaned = re.sub(" ", "", raw_line)
            cleaned = re.sub("\t", "", cleaned)
            # NOTE(review): the emptiness test is on the raw line, not the
            # cleaned text -- preserved to keep behavior identical.
            if raw_line == "":
                continue
            collected.append(cleaned + "\n")

    with open("appendix_text.txt", "w") as f:
        f.writelines(collected)
    return
+
+
def get_data_from_paddle():
    """Convert paddle-format label files into ``image/<name>\\t<text>`` lines.

    Each input line is "<image name> <char idx> <char idx> ..."; indices
    are decoded through ``char.txt`` and the result is written to
    ``paddle_data.txt``.
    """
    path = "D:\\DataSet\\"
    with open(path + "char.txt", "r") as f:
        dictionary = f.readlines()
    with open(path + "data_train.txt") as f:
        train_list = f.readlines()
    with open(path + "data_test.txt") as f:
        test_list = f.readlines()

    text_list = []
    for data in train_list + test_list:
        ss = data[:-1].split(" ")
        image_path = "image/" + ss[0]
        # Decode each character index; dictionary lines end with "\n".
        text = "".join(dictionary[int(num)][:-1] for num in ss[1:])
        if text == "":
            print("no text!")
            continue
        text_list.append(image_path + "\t" + text + "\n")

    with open("paddle_data.txt", "w") as f:
        f.writelines(text_list)
+
+
+def create_number_list2(number):
+    no_list = []
+    for i in range(number):
+        c1 = random.choice([0, 1, 1])
+        if c1:
+            no = random.randint(0, 10)
+        else:
+            no = random.randint(10, 100)
+        no_list.append(str(no) + "\n")
+    # print(no_list)
+    return no_list
+
+
+def create_number_list3(number):
+    no_list = []
+    for i in range(number):
+        # 选择小数整数
+        c1 = random.choice([0, 1, 1])
+        if c1:
+            no = random.randint(10000, 1000000000)
+            no = str(no)
+            # 加3位分割逗号
+            # 选择中英文逗号
+            c2 = random.choice([',', ',', ','])
+            for i in range(len(no)-3, 0, -3):
+                no = no[:i] + c2 + no[i:]
+        else:
+            no = random.uniform(10000, 1000000000)
+            no = str(no)
+            nos = no.split(".")
+            no_1 = nos[0]
+            no_2 = nos[1]
+            # 加3位分割逗号
+            # 选择中英文逗号
+            c2 = random.choice([',', ',', ','])
+            for i in range(len(no_1)-3, 0, -3):
+                no_1 = no_1[:i] + c2 + no_1[i:]
+            no = no_1 + "." + no_2
+
+        # 选择是否加¥符号
+        c3 = random.choice(['', '¥', "¥ "])
+        no_list.append(c3 + str(no) + "\n")
+    print(no_list)
+    return no_list
+
+
if __name__ == '__main__':

    # df = pd.read_csv('C:\\Users\\Administrator\\Desktop\\orgs.csv')
    # create_longSentence_image(df)
    # s = 100
    # image = cv2.imread("text_384178.jpg")
    # print(image.shape)
    # list1 = create_text_list(100)
    # for l in list1:
    #     print(l)
    # print(len(list1))

    # create_mix_txt()
    # get_mix_data_from_file(2)
    # create_number_list(10)

    # with open("../../train_data/bidi_data/mix_data2/rec_gt_test.txt", "r") as f:
    #     _list = f.readlines()
    # for line in _list:
    #     _str = line.split("\t")[-1][:-1]
    #     print(_str, type(_str))
    #     create_image("../../train_data/bidi_data/mix_data2/", "", _str)

    # get_data_from_appendix()
    # get_data_from_paddle()
    # delete_image("../../train_data/bidi_data/mix_data/", "train/text_0.jpg")

    # with open("paddle_data.txt", "r") as f:
    #     list1 = f.readlines()
    # print(len(list1))
    #
    # with open(train_data_path, "r") as f:
    #     list2 = f.readlines()
    # train_data_list = list2 + list1[0:int(len(list1)*0.95)]
    # with open(train_data_path, "w") as f:
    #     f.writelines(train_data_list)
    #
    # with open(test_data_path, "r") as f:
    #     list3 = f.readlines()
    # test_data_list = list3 + list1[int(len(list1)*0.95):]
    # with open(test_data_path, "w") as f:
    #     f.writelines(test_data_list)

    # Generate 2M formatted-amount lines, label them with image names
    # starting at index 23,000,000, and splice them over the last 2M lines
    # of the train label file.
    no_list = create_number_list3(2000000)
    i = 23000000
    train_list = []
    for no in no_list:
        train_list.append("train/text_" + str(i) + ".jpg" + "\t" + no)
        i += 1
    # print(train_list)

    # Replace the tail of the existing train file with the new entries.
    # NOTE(review): assumes the file already has at least 2M lines; with
    # fewer, list3[:-2000000] is empty and the file is overwritten entirely.
    with open(train_data_path, "r") as f:
        list3 = f.readlines()
    _list = list3[:-2000000] + train_list
    with open(train_data_path, "w") as f:
        f.writelines(_list)

+ 42 - 0
ocr/ppocr/losses/__init__.py

@@ -0,0 +1,42 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+
def build_loss(config):
    """Build a loss module from a config dict.

    Args:
        config (dict): must contain a 'name' key naming one of the supported
            loss classes; the remaining keys are forwarded to the loss
            constructor as keyword arguments.

    Returns:
        nn.Layer: the instantiated loss module.
    """
    # det loss
    from .det_db_loss import DBLoss
    from .det_east_loss import EASTLoss
    from .det_sast_loss import SASTLoss

    # rec loss
    from .rec_ctc_loss import CTCLoss
    from .rec_att_loss import AttentionLoss
    from .rec_srn_loss import SRNLoss

    # cls loss
    from .cls_loss import ClsLoss

    # Explicit name -> class mapping instead of eval(): same behavior for
    # supported names, but no arbitrary-code evaluation of the config value.
    support_dict = {
        'DBLoss': DBLoss,
        'EASTLoss': EASTLoss,
        'SASTLoss': SASTLoss,
        'CTCLoss': CTCLoss,
        'ClsLoss': ClsLoss,
        'AttentionLoss': AttentionLoss,
        'SRNLoss': SRNLoss,
    }

    config = copy.deepcopy(config)
    module_name = config.pop('name')
    assert module_name in support_dict, Exception('loss only support {}'.format(
        list(support_dict)))
    module_class = support_dict[module_name](**config)
    return module_class

+ 30 - 0
ocr/ppocr/losses/cls_loss.py

@@ -0,0 +1,30 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+
+
class ClsLoss(nn.Layer):
    """Mean-reduced cross-entropy loss for the text-direction classifier."""

    def __init__(self, **kwargs):
        super(ClsLoss, self).__init__()
        # Softmax cross entropy averaged over the batch.
        self.loss_func = nn.CrossEntropyLoss(reduction='mean')

    def __call__(self, predicts, batch):
        # batch[1] carries the ground-truth direction labels.
        return {'loss': self.loss_func(input=predicts, label=batch[1])}

+ 205 - 0
ocr/ppocr/losses/det_basic_loss.py

@@ -0,0 +1,205 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+
class BalanceLoss(nn.Layer):
    """Hard-negative balancing wrapper around a base pixel loss (used by DB)."""

    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 negative_ratio=3,
                 return_origin=False,
                 eps=1e-6,
                 **kwargs):
        """
        The BalanceLoss for Differentiable Binarization text detection.

        Args:
            balance_loss (bool): whether to balance the loss, default True.
            main_loss_type (str): one of ['CrossEntropy', 'DiceLoss',
                'Euclidean', 'BCELoss', 'MaskL1Loss'], default 'DiceLoss'.
            negative_ratio (int|float): max negatives kept per positive,
                default 3.
            return_origin (bool): also return the unbalanced loss,
                default False.
            eps (float): numeric stabilizer, default 1e-6.
        """
        super(BalanceLoss, self).__init__()
        self.balance_loss = balance_loss
        self.main_loss_type = main_loss_type
        self.negative_ratio = negative_ratio
        self.return_origin = return_origin
        self.eps = eps

        # Instantiate the underlying per-pixel loss by name.
        if self.main_loss_type == "CrossEntropy":
            self.loss = nn.CrossEntropyLoss()
        elif self.main_loss_type == "Euclidean":
            self.loss = nn.MSELoss()
        elif self.main_loss_type == "DiceLoss":
            self.loss = DiceLoss(self.eps)
        elif self.main_loss_type == "BCELoss":
            self.loss = BCELoss(reduction='none')
        elif self.main_loss_type == "MaskL1Loss":
            self.loss = MaskL1Loss(self.eps)
        else:
            loss_type = [
                'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
            ]
            raise Exception(
                "main_loss_type in BalanceLoss() can only be one of {}".format(
                    loss_type))

    def forward(self, pred, gt, mask=None):
        """
        Compute the balanced loss.

        Args:
            pred (Tensor): predicted feature maps.
            gt (Tensor): ground truth feature maps.
            mask (Tensor): valid-region mask.
        Returns:
            Tensor: balanced loss; a (balanced, raw) pair when
            return_origin is True.
        """
        # if self.main_loss_type in ['DiceLoss']:
        #     # For the loss that returns to scalar value, perform ohem on the mask
        #     mask = ohem_batch(pred, gt, mask, self.negative_ratio)
        #     loss = self.loss(pred, gt, mask)
        #     return loss

        positive = gt * mask
        negative = (1 - gt) * mask

        # Cap the negatives at negative_ratio * number of positives.
        positive_count = int(positive.sum())
        negative_count = int(
            min(negative.sum(), positive_count * self.negative_ratio))
        loss = self.loss(pred, gt, mask=mask)

        if not self.balance_loss:
            return loss

        positive_loss = positive * loss
        negative_loss = negative * loss
        negative_loss = paddle.reshape(negative_loss, shape=[-1])
        if negative_count > 0:
            # NOTE(review): this assumes Tensor.sort(descending=True) returns
            # the sorted values so the slice keeps the negative_count largest
            # per-pixel losses — confirm against the Paddle version in use.
            sort_loss = negative_loss.sort(descending=True)
            negative_loss = sort_loss[:negative_count]
            # negative_loss, _ = paddle.topk(negative_loss, k=negative_count_int)
            balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
                positive_count + negative_count + self.eps)
        else:
            # No negatives selected: average over positives only.
            balance_loss = positive_loss.sum() / (positive_count + self.eps)
        if self.return_origin:
            return balance_loss, loss

        return balance_loss
+
+
class DiceLoss(nn.Layer):
    """Masked dice loss: 1 - 2*|P∩G| / (|P| + |G|), restricted to mask."""

    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps  # keeps the union denominator away from zero

    def forward(self, pred, gt, mask, weights=None):
        """
        DiceLoss function.
        """
        assert pred.shape == gt.shape
        assert pred.shape == mask.shape
        if weights is not None:
            # Fold the per-pixel weights into the mask.
            assert weights.shape == mask.shape
            mask = weights * mask

        overlap = paddle.sum(pred * gt * mask)
        total = paddle.sum(pred * mask) + paddle.sum(gt * mask) + self.eps
        loss = 1 - 2.0 * overlap / total
        assert loss <= 1
        return loss
+
+
class MaskL1Loss(nn.Layer):
    """L1 distance between pred and gt, averaged over the masked region."""

    def __init__(self, eps=1e-6):
        super(MaskL1Loss, self).__init__()
        self.eps = eps  # guards against an all-zero mask

    def forward(self, pred, gt, mask):
        """
        Mask L1 Loss
        """
        masked_abs_diff = (paddle.abs(pred - gt) * mask).sum()
        loss = masked_abs_diff / (mask.sum() + self.eps)
        # mean() of the scalar is a no-op kept for parity with the original.
        return paddle.mean(loss)
+
+
class BCELoss(nn.Layer):
    """Thin wrapper around binary cross entropy.

    The mask/weight/name arguments are accepted for interface compatibility
    with the other basic losses but are not used here.
    """

    def __init__(self, reduction='mean'):
        super(BCELoss, self).__init__()
        self.reduction = reduction

    def forward(self, input, label, mask=None, weight=None, name=None):
        return F.binary_cross_entropy(input, label, reduction=self.reduction)
+
+
def ohem_single(score, gt_text, training_mask, ohem_ratio):
    """Online hard example mining for a single (H, W) sample.

    Keeps every positive pixel plus at most ohem_ratio * #positives of the
    hardest (highest-scoring) negative pixels.

    Args:
        score (np.ndarray): (H, W) predicted text scores.
        gt_text (np.ndarray): (H, W) ground-truth text map (>0.5 = text).
        training_mask (np.ndarray): (H, W) valid-pixel mask (>0.5 = valid).
        ohem_ratio (int|float): max negatives per positive.

    Returns:
        np.ndarray: (1, H, W) float32 selection mask.
    """

    def _batched_float(mask2d):
        # (H, W) -> (1, H, W) float32, matching the network's layout.
        return mask2d.reshape(
            1, mask2d.shape[0], mask2d.shape[1]).astype('float32')

    pos_mask = gt_text > 0.5
    # Positives are text pixels that are not masked out.
    pos_num = int(np.sum(pos_mask)) - int(
        np.sum(pos_mask & (training_mask <= 0.5)))

    if pos_num == 0:
        # No positives: keep the original training mask unchanged.
        return _batched_float(training_mask)

    neg_num = int(np.sum(gt_text <= 0.5))
    neg_num = int(min(pos_num * ohem_ratio, neg_num))
    if neg_num == 0:
        return _batched_float(training_mask)

    # Threshold at the neg_num-th largest negative score so only the
    # hardest negatives survive.
    neg_score = score[gt_text <= 0.5]
    threshold = np.sort(neg_score)[::-1][neg_num - 1]
    selected = ((score >= threshold) | pos_mask) & (training_mask > 0.5)
    return _batched_float(selected)
+
+
def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
    """Apply ohem_single to every sample of a batch.

    Args:
        scores (Tensor): (N, H, W) predicted text scores.
        gt_texts (Tensor): (N, H, W) ground-truth text maps.
        training_masks (Tensor): (N, H, W) valid-pixel masks.
        ohem_ratio (int|float): max negatives per positive.

    Returns:
        Tensor: (N, H, W) float32 selection masks.
    """
    scores = scores.numpy()
    gt_texts = gt_texts.numpy()
    training_masks = training_masks.numpy()

    selected_masks = []
    for i in range(scores.shape[0]):
        selected_masks.append(
            ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
                i, :, :], ohem_ratio))

    selected_masks = np.concatenate(selected_masks, 0)
    # Fix: paddle.to_variable is the removed Paddle 1.x dygraph API;
    # paddle.to_tensor is the 2.x equivalent used elsewhere in this codebase.
    selected_masks = paddle.to_tensor(selected_masks)

    return selected_masks

+ 72 - 0
ocr/ppocr/losses/det_db_loss.py

@@ -0,0 +1,72 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+
+from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
+
+
class DBLoss(nn.Layer):
    """Differentiable Binarization (DB) loss.

    Combines three terms computed from the three predicted maps:
      * balanced BCE on the shrink map, weighted by alpha
      * masked L1 on the threshold map, weighted by beta
      * dice loss on the binary map
    """

    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 **kwargs):
        super(DBLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        self.bce_loss = BalanceLoss(
            balance_loss=balance_loss,
            main_loss_type=main_loss_type,
            negative_ratio=ohem_ratio)

    def forward(self, predicts, labels):
        # Labels: [image, threshold_map, threshold_mask, shrink_map, shrink_mask]
        thresh_map_gt, thresh_mask, shrink_map_gt, shrink_mask = labels[1:]
        maps = predicts['maps']
        # Channel layout of the prediction: 0=shrink, 1=threshold, 2=binary.
        shrink_pred = maps[:, 0, :, :]
        thresh_pred = maps[:, 1, :, :]
        binary_pred = maps[:, 2, :, :]

        loss_shrink_maps = self.alpha * self.bce_loss(
            shrink_pred, shrink_map_gt, shrink_mask)
        loss_threshold_maps = self.beta * self.l1_loss(
            thresh_pred, thresh_map_gt, thresh_mask)
        loss_binary_maps = self.dice_loss(binary_pred, shrink_map_gt,
                                          shrink_mask)

        total = loss_shrink_maps + loss_threshold_maps + loss_binary_maps
        return {
            'loss': total,
            "loss_shrink_maps": loss_shrink_maps,
            "loss_threshold_maps": loss_threshold_maps,
            "loss_binary_maps": loss_binary_maps,
        }

+ 63 - 0
ocr/ppocr/losses/det_east_loss.py

@@ -0,0 +1,63 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from .det_basic_loss import DiceLoss
+
+
class EASTLoss(nn.Layer):
    """EAST detection loss: dice loss on the score map plus a smooth-L1
    style term on the 8-channel geometry map."""

    def __init__(self,
                 eps=1e-6,
                 **kwargs):
        super(EASTLoss, self).__init__()
        self.dice_loss = DiceLoss(eps=eps)

    def forward(self, predicts, labels):
        """Compute the EAST loss.

        Args:
            predicts (dict): holds 'f_score' and 'f_geo' prediction maps.
            labels (list): [image, score_map, geo_map, mask] from the loader.
        Returns:
            dict: 'loss' plus individual 'dice_loss' and 'smooth_l1_loss'.
        """
        l_score, l_geo, l_mask = labels[1:]
        f_score = predicts['f_score']
        f_geo = predicts['f_geo']

        dice_loss = self.dice_loss(f_score, l_score, l_mask)

        # smoooth_l1_loss over the 8 geometry channels; the extra (9th)
        # channel of l_geo carries the per-pixel normalization weight.
        channels = 8
        l_geo_split = paddle.split(
            l_geo, num_or_sections=channels + 1, axis=1)
        f_geo_split = paddle.split(f_geo, num_or_sections=channels, axis=1)
        smooth_l1 = 0
        for i in range(0, channels):
            geo_diff = l_geo_split[i] - f_geo_split[i]
            abs_geo_diff = paddle.abs(geo_diff)
            # 1.0 where |diff| < score map value (quadratic region), else 0.0.
            smooth_l1_sign = paddle.less_than(abs_geo_diff, l_score)
            smooth_l1_sign = paddle.cast(smooth_l1_sign, dtype='float32')
            in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + \
                (abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
            out_loss = l_geo_split[-1] / channels * in_loss * l_score
            smooth_l1 += out_loss
        # NOTE(review): out_loss is already scaled by l_score, so the mean
        # below applies l_score a second time — intentional per upstream?
        smooth_l1_loss = paddle.mean(smooth_l1 * l_score)

        # Score-map dice term is heavily down-weighted relative to geometry.
        dice_loss = dice_loss * 0.01
        total_loss = dice_loss + smooth_l1_loss
        losses = {"loss":total_loss, \
                  "dice_loss":dice_loss,\
                  "smooth_l1_loss":smooth_l1_loss}
        return losses

+ 121 - 0
ocr/ppocr/losses/det_sast_loss.py

@@ -0,0 +1,121 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from .det_basic_loss import DiceLoss
+import numpy as np
+
+
class SASTLoss(nn.Layer):
    """SAST detection loss: dice-style score loss plus smooth-L1 terms on
    the border, TVO and TCO regression maps."""

    def __init__(self, eps=1e-6, **kwargs):
        super(SASTLoss, self).__init__()
        self.dice_loss = DiceLoss(eps=eps)

    def forward(self, predicts, labels):
        """
        tcl_pos: N x 128 x 3
        tcl_mask: N x 128 x 1
        tcl_label: N x X list or LoDTensor
        """

        f_score = predicts['f_score']
        f_border = predicts['f_border']
        f_tvo = predicts['f_tvo']
        f_tco = predicts['f_tco']

        # Ground truth: score map, border map, valid mask, TVO and TCO maps.
        l_score, l_border, l_mask, l_tvo, l_tco = labels[1:]

        #score_loss (dice-style: 1 - 2*intersection/union over the mask)
        intersection = paddle.sum(f_score * l_score * l_mask)
        union = paddle.sum(f_score * l_mask) + paddle.sum(l_score * l_mask)
        score_loss = 1.0 - 2 * intersection / (union + 1e-5)

        #border loss
        # Last channel of l_border is a per-pixel normalization weight;
        # broadcast it (and score/mask) to the 4 regression channels.
        l_border_split, l_border_norm = paddle.split(
            l_border, num_or_sections=[4, 1], axis=1)
        f_border_split = f_border
        border_ex_shape = l_border_norm.shape * np.array([1, 4, 1, 1])
        l_border_norm_split = paddle.expand(
            x=l_border_norm, shape=border_ex_shape)
        l_border_score = paddle.expand(x=l_score, shape=border_ex_shape)
        l_border_mask = paddle.expand(x=l_mask, shape=border_ex_shape)

        # Smooth-L1: quadratic below 1.0, linear above (sign is a 0/1 gate).
        border_diff = l_border_split - f_border_split
        abs_border_diff = paddle.abs(border_diff)
        border_sign = abs_border_diff < 1.0
        border_sign = paddle.cast(border_sign, dtype='float32')
        border_sign.stop_gradient = True
        border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
                    (abs_border_diff - 0.5) * (1.0 - border_sign)
        border_out_loss = l_border_norm_split * border_in_loss
        border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
                    (paddle.sum(l_border_score * l_border_mask) + 1e-5)

        #tvo_loss (same smooth-L1 scheme over the 8 TVO channels)
        l_tvo_split, l_tvo_norm = paddle.split(
            l_tvo, num_or_sections=[8, 1], axis=1)
        f_tvo_split = f_tvo
        tvo_ex_shape = l_tvo_norm.shape * np.array([1, 8, 1, 1])
        l_tvo_norm_split = paddle.expand(x=l_tvo_norm, shape=tvo_ex_shape)
        l_tvo_score = paddle.expand(x=l_score, shape=tvo_ex_shape)
        l_tvo_mask = paddle.expand(x=l_mask, shape=tvo_ex_shape)
        #
        tvo_geo_diff = l_tvo_split - f_tvo_split
        abs_tvo_geo_diff = paddle.abs(tvo_geo_diff)
        tvo_sign = abs_tvo_geo_diff < 1.0
        tvo_sign = paddle.cast(tvo_sign, dtype='float32')
        tvo_sign.stop_gradient = True
        tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
                    (abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
        tvo_out_loss = l_tvo_norm_split * tvo_in_loss
        tvo_loss = paddle.sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
                    (paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5)

        #tco_loss (same smooth-L1 scheme over the 2 TCO channels)
        l_tco_split, l_tco_norm = paddle.split(
            l_tco, num_or_sections=[2, 1], axis=1)
        f_tco_split = f_tco
        tco_ex_shape = l_tco_norm.shape * np.array([1, 2, 1, 1])
        l_tco_norm_split = paddle.expand(x=l_tco_norm, shape=tco_ex_shape)
        l_tco_score = paddle.expand(x=l_score, shape=tco_ex_shape)
        l_tco_mask = paddle.expand(x=l_mask, shape=tco_ex_shape)

        tco_geo_diff = l_tco_split - f_tco_split
        abs_tco_geo_diff = paddle.abs(tco_geo_diff)
        tco_sign = abs_tco_geo_diff < 1.0
        tco_sign = paddle.cast(tco_sign, dtype='float32')
        tco_sign.stop_gradient = True
        tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
                    (abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
        tco_out_loss = l_tco_norm_split * tco_in_loss
        tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / \
                    (paddle.sum(l_tco_score * l_tco_mask) + 1e-5)

        # total loss: fixed per-term weights
        tvo_lw, tco_lw = 1.5, 1.5
        score_lw, border_lw = 1.0, 1.0
        total_loss = score_loss * score_lw + border_loss * border_lw + \
                    tvo_loss * tvo_lw + tco_loss * tco_lw

        losses = {'loss':total_loss, "score_loss":score_loss,\
            "border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss}
        return losses

+ 39 - 0
ocr/ppocr/losses/rec_att_loss.py

@@ -0,0 +1,39 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
class AttentionLoss(nn.Layer):
    """Summed cross-entropy loss for attention-based recognition heads."""

    def __init__(self, **kwargs):
        super(AttentionLoss, self).__init__()
        # Per-element CE; the sum reduction is applied explicitly below.
        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')

    def forward(self, predicts, batch):
        # Ground-truth character ids and their sequence lengths.
        targets = batch[1].astype("int64")
        label_lengths = batch[2].astype('int64')
        assert len(targets.shape) == len(list(predicts.shape)) - 1, \
            "The target's shape and inputs's shape is [N, d] and [N, num_steps]"

        # Flatten (N, num_steps, num_classes) -> (N*num_steps, num_classes)
        # and (N, num_steps) -> (N*num_steps,) for the element-wise CE.
        flat_logits = paddle.reshape(predicts, [-1, predicts.shape[-1]])
        flat_targets = paddle.reshape(targets, [-1])

        return {'loss': paddle.sum(self.loss_func(flat_logits, flat_targets))}

+ 36 - 0
ocr/ppocr/losses/rec_ctc_loss.py

@@ -0,0 +1,36 @@
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
class CTCLoss(nn.Layer):
    """CTC loss for sequence recognition; index 0 is the blank symbol."""

    def __init__(self, **kwargs):
        super(CTCLoss, self).__init__()
        self.loss_func = nn.CTCLoss(blank=0, reduction='none')

    def __call__(self, predicts, batch):
        # (N, T, C) -> (T, N, C): CTC expects time-major logits.
        predicts = predicts.transpose((1, 0, 2))
        seq_len, batch_size, _ = predicts.shape
        # Every sample uses the full time dimension as its input length.
        preds_lengths = paddle.to_tensor([seq_len] * batch_size, dtype='int64')
        labels = batch[1].astype("int32")
        label_lengths = batch[2].astype('int64')
        raw_loss = self.loss_func(predicts, labels, preds_lengths,
                                  label_lengths)
        # Average the per-sample losses into a scalar.
        return {'loss': raw_loss.mean()}

+ 47 - 0
ocr/ppocr/losses/rec_srn_loss.py

@@ -0,0 +1,47 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
class SRNLoss(nn.Layer):
    """Combined SRN loss: weighted sum of the word, GSRM and VSFD branches."""

    def __init__(self, **kwargs):
        super(SRNLoss, self).__init__()
        self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")

    def forward(self, predicts, batch):
        # Flatten labels to (N*T, 1) int64 for the summed cross entropy.
        flat_label = paddle.reshape(
            x=paddle.cast(x=batch[1], dtype='int64'), shape=[-1, 1])

        def _branch_cost(logits):
            # Summed CE against the flattened labels, kept as a length-1 tensor.
            cost = self.loss_func(logits, label=flat_label)
            return paddle.reshape(x=paddle.sum(cost), shape=[1])

        cost_word = _branch_cost(predicts['word_out'])
        cost_gsrm = _branch_cost(predicts['gsrm_out'])
        cost_vsfd = _branch_cost(predicts['predict'])

        # Fixed branch weights from the SRN training recipe.
        sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15

        return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}

+ 37 - 0
ocr/ppocr/metrics/__init__.py

@@ -0,0 +1,37 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import copy
+
+__all__ = ['build_metric']
+
+
def build_metric(config):
    """Build a metric object from a config dict.

    Args:
        config (dict): must contain a 'name' key naming one of the supported
            metric classes; the remaining keys are forwarded to the metric
            constructor as keyword arguments.

    Returns:
        object: the instantiated metric.
    """
    from .det_metric import DetMetric
    from .rec_metric import RecMetric
    from .cls_metric import ClsMetric

    # Explicit name -> class mapping instead of eval(), matching build_loss.
    support_dict = {
        'DetMetric': DetMetric,
        'RecMetric': RecMetric,
        'ClsMetric': ClsMetric,
    }

    config = copy.deepcopy(config)
    module_name = config.pop('name')
    assert module_name in support_dict, Exception(
        'metric only support {}'.format(list(support_dict)))
    module_class = support_dict[module_name](**config)
    return module_class

+ 45 - 0
ocr/ppocr/metrics/cls_metric.py

@@ -0,0 +1,45 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
class ClsMetric(object):
    """Running accuracy metric for the direction classifier."""

    def __init__(self, main_indicator='acc', **kwargs):
        # Indicator name used for best-model selection by the trainer.
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, pred_label, *args, **kwargs):
        """Accumulate one batch and return its accuracy.

        pred_label is (preds, labels) where each element of preds is a
        (predicted_class, confidence) pair and each element of labels is a
        (target_class, _) pair.
        """
        preds, labels = pred_label
        batch_hits = 0
        batch_total = 0
        for (pred, _conf), (target, _) in zip(preds, labels):
            batch_total += 1
            if pred == target:
                batch_hits += 1
        self.correct_num += batch_hits
        self.all_num += batch_total
        return {'acc': batch_hits / batch_total, }

    def get_metric(self):
        """Return the accumulated accuracy and reset the counters."""
        acc = self.correct_num / self.all_num
        self.reset()
        return {'acc': acc}

    def reset(self):
        # Counters accumulated across __call__ invocations.
        self.correct_num = 0
        self.all_num = 0

+ 72 - 0
ocr/ppocr/metrics/det_metric.py

@@ -0,0 +1,72 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+__all__ = ['DetMetric']
+
+from .eval_det_iou import DetectionIoUEvaluator
+
+
class DetMetric(object):
    """Precision/recall/hmean metric for text detection, IoU based."""

    def __init__(self, main_indicator='hmean', **kwargs):
        self.evaluator = DetectionIoUEvaluator()
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        '''
       batch: a list produced by dataloaders.
           image: np.ndarray  of shape (N, C, H, W).
           ratio_list: np.ndarray  of shape(N,2)
           polygons: np.ndarray  of shape (N, K, 4, 2), the polygons of objective regions.
           ignore_tags: np.ndarray  of shape (N, K), indicates whether a region is ignorable or not.
       preds: a list of dict produced by post process
            points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
       '''
        gt_polygons_batch = batch[2]
        ignore_tags_batch = batch[3]
        for pred, gt_polygons, ignore_tags in zip(preds, gt_polygons_batch,
                                                  ignore_tags_batch):
            # Assemble the ground-truth records for this image.
            gt_info_list = []
            for polygon, ignored in zip(gt_polygons, ignore_tags):
                gt_info_list.append({
                    'points': polygon,
                    'text': '',
                    'ignore': ignored
                })
            # Assemble the detection records for this image.
            det_info_list = [{'points': det_polygon, 'text': ''}
                             for det_polygon in pred['points']]
            self.results.append(
                self.evaluator.evaluate_image(gt_info_list, det_info_list))

    def get_metric(self):
        """
        return metrics {
                 'precision': 0,
                 'recall': 0,
                 'hmean': 0
            }
        """
        combined = self.evaluator.combine_results(self.results)
        self.reset()
        return combined

    def reset(self):
        self.results = []  # per-image evaluation results

+ 235 - 0
ocr/ppocr/metrics/eval_det_iou.py

@@ -0,0 +1,235 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from collections import namedtuple
+import numpy as np
+from shapely.geometry import Polygon
+"""
+reference from :
+https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
+"""
+
+
+class DetectionIoUEvaluator(object):
+    # ICDAR-2015-style detection evaluator: greedy one-to-one matching of
+    # gt/det polygons by IoU, with "don't care" regions excluded from counts.
+    def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
+        # iou_constraint: minimum IoU for a gt/det pair to count as a match.
+        # area_precision_constraint: overlap ratio above which a detection
+        # covering a don't-care gt is itself marked don't-care.
+        self.iou_constraint = iou_constraint
+        self.area_precision_constraint = area_precision_constraint
+
+    def evaluate_image(self, gt, pred):
+        # gt/pred: lists of dicts with 'points' (polygon vertices); gt items
+        # additionally carry an 'ignore' flag marking don't-care regions.
+        def get_union(pD, pG):
+            return Polygon(pD).union(Polygon(pG)).area
+
+        def get_intersection_over_union(pD, pG):
+            return get_intersection(pD, pG) / get_union(pD, pG)
+
+        def get_intersection(pD, pG):
+            return Polygon(pD).intersection(Polygon(pG)).area
+
+        def compute_ap(confList, matchList, numGtCare):
+            # average precision over confidence-sorted matches
+            # (defined here but never called by this evaluator)
+            correct = 0
+            AP = 0
+            if len(confList) > 0:
+                confList = np.array(confList)
+                matchList = np.array(matchList)
+                sorted_ind = np.argsort(-confList)
+                confList = confList[sorted_ind]
+                matchList = matchList[sorted_ind]
+                for n in range(len(confList)):
+                    match = matchList[n]
+                    if match:
+                        correct += 1
+                        AP += float(correct) / (n + 1)
+
+                if numGtCare > 0:
+                    AP /= numGtCare
+
+            return AP
+
+        perSampleMetrics = {}
+
+        matchedSum = 0
+
+        # NOTE(review): Rectangle and the arr*/detMatchedNums accumulators
+        # below are leftovers from the upstream script and are never read.
+        Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
+
+        numGlobalCareGt = 0
+        numGlobalCareDet = 0
+
+        arrGlobalConfidences = []
+        arrGlobalMatches = []
+
+        recall = 0
+        precision = 0
+        hmean = 0
+
+        detMatched = 0
+
+        iouMat = np.empty([1, 1])
+
+        gtPols = []
+        detPols = []
+
+        gtPolPoints = []
+        detPolPoints = []
+
+        # Array of Ground Truth Polygons' keys marked as don't Care
+        gtDontCarePolsNum = []
+        # Array of Detected Polygons' matched with a don't Care GT
+        detDontCarePolsNum = []
+
+        pairs = []
+        detMatchedNums = []
+
+        arrSampleConfidences = []
+        arrSampleMatch = []
+
+        evaluationLog = ""
+
+        # collect valid gt polygons; invalid / self-intersecting shapes are skipped
+        # print(len(gt))
+        for n in range(len(gt)):
+            points = gt[n]['points']
+            # transcription = gt[n]['text']
+            dontCare = gt[n]['ignore']
+            #             points = Polygon(points)
+            #             points = points.buffer(0)
+            if not Polygon(points).is_valid or not Polygon(points).is_simple:
+                continue
+
+            gtPol = points
+            gtPols.append(gtPol)
+            gtPolPoints.append(points)
+            if dontCare:
+                gtDontCarePolsNum.append(len(gtPols) - 1)
+
+        evaluationLog += "GT polygons: " + str(len(gtPols)) + (
+            " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
+            if len(gtDontCarePolsNum) > 0 else "\n")
+
+        # collect valid detections; a detection mostly covering a don't-care
+        # gt is marked don't-care too (and excluded from precision)
+        for n in range(len(pred)):
+            points = pred[n]['points']
+            #             points = Polygon(points)
+            #             points = points.buffer(0)
+            if not Polygon(points).is_valid or not Polygon(points).is_simple:
+                continue
+
+            detPol = points
+            detPols.append(detPol)
+            detPolPoints.append(points)
+            if len(gtDontCarePolsNum) > 0:
+                for dontCarePol in gtDontCarePolsNum:
+                    dontCarePol = gtPols[dontCarePol]
+                    intersected_area = get_intersection(dontCarePol, detPol)
+                    pdDimensions = Polygon(detPol).area
+                    precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
+                    if (precision > self.area_precision_constraint):
+                        detDontCarePolsNum.append(len(detPols) - 1)
+                        break
+
+        evaluationLog += "DET polygons: " + str(len(detPols)) + (
+            " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
+            if len(detDontCarePolsNum) > 0 else "\n")
+
+        if len(gtPols) > 0 and len(detPols) > 0:
+            # Calculate IoU and precision matrixs
+            outputShape = [len(gtPols), len(detPols)]
+            iouMat = np.empty(outputShape)
+            gtRectMat = np.zeros(len(gtPols), np.int8)
+            detRectMat = np.zeros(len(detPols), np.int8)
+            for gtNum in range(len(gtPols)):
+                for detNum in range(len(detPols)):
+                    pG = gtPols[gtNum]
+                    pD = detPols[detNum]
+                    iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
+
+            # greedy one-to-one matching: first qualifying pair wins, each
+            # polygon is used at most once (gtRectMat/detRectMat mark usage)
+            for gtNum in range(len(gtPols)):
+                for detNum in range(len(detPols)):
+                    if gtRectMat[gtNum] == 0 and detRectMat[
+                            detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
+                        if iouMat[gtNum, detNum] > self.iou_constraint:
+                            gtRectMat[gtNum] = 1
+                            detRectMat[detNum] = 1
+                            detMatched += 1
+                            pairs.append({'gt': gtNum, 'det': detNum})
+                            detMatchedNums.append(detNum)
+                            evaluationLog += "Match GT #" + \
+                                str(gtNum) + " with Det #" + str(detNum) + "\n"
+
+        # precision/recall over the "cared" polygons only
+        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
+        numDetCare = (len(detPols) - len(detDontCarePolsNum))
+        if numGtCare == 0:
+            recall = float(1)
+            precision = float(0) if numDetCare > 0 else float(1)
+        else:
+            recall = float(detMatched) / numGtCare
+            precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
+
+        hmean = 0 if (precision + recall) == 0 else 2.0 * \
+            precision * recall / (precision + recall)
+
+        matchedSum += detMatched
+        numGlobalCareGt += numGtCare
+        numGlobalCareDet += numDetCare
+
+        perSampleMetrics = {
+            'precision': precision,
+            'recall': recall,
+            'hmean': hmean,
+            'pairs': pairs,
+            # IoU matrix omitted for very large images to keep results small
+            'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
+            'gtPolPoints': gtPolPoints,
+            'detPolPoints': detPolPoints,
+            'gtCare': numGtCare,
+            'detCare': numDetCare,
+            'gtDontCare': gtDontCarePolsNum,
+            'detDontCare': detDontCarePolsNum,
+            'detMatched': detMatched,
+            'evaluationLog': evaluationLog
+        }
+
+        return perSampleMetrics
+
+    def combine_results(self, results):
+        # micro-average: sum matches and cared counts over all images, then
+        # compute dataset-level precision/recall/hmean
+        numGlobalCareGt = 0
+        numGlobalCareDet = 0
+        matchedSum = 0
+        for result in results:
+            numGlobalCareGt += result['gtCare']
+            numGlobalCareDet += result['detCare']
+            matchedSum += result['detMatched']
+
+        methodRecall = 0 if numGlobalCareGt == 0 else float(
+            matchedSum) / numGlobalCareGt
+        methodPrecision = 0 if numGlobalCareDet == 0 else float(
+            matchedSum) / numGlobalCareDet
+        methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
+            methodRecall * methodPrecision / (methodRecall + methodPrecision)
+        # print(methodRecall, methodPrecision, methodHmean)
+        # sys.exit(-1)
+        methodMetrics = {
+            'precision': methodPrecision,
+            'recall': methodRecall,
+            'hmean': methodHmean
+        }
+
+        return methodMetrics
+
+
+if __name__ == '__main__':
+    # Smoke test: two unit-square ground truths, one detection roughly
+    # covering the first one — presumably the detection matches only the
+    # first gt (recall 0.5, precision 1.0); verify by running this file.
+    evaluator = DetectionIoUEvaluator()
+    gts = [[{
+        'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
+        'text': 1234,
+        'ignore': False,
+    }, {
+        'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
+        'text': 5678,
+        'ignore': False,
+    }]]
+    preds = [[{
+        'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
+        'text': 123,
+        'ignore': False,
+    }]]
+    results = []
+    for gt, pred in zip(gts, preds):
+        results.append(evaluator.evaluate_image(gt, pred))
+    metrics = evaluator.combine_results(results)
+    print(metrics)

+ 62 - 0
ocr/ppocr/metrics/rec_metric.py

@@ -0,0 +1,62 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import Levenshtein
+
+
+class RecMetric(object):
+    # Text-recognition metric: exact-match accuracy (spaces ignored) plus
+    # normalised edit distance, accumulated across batches.
+    def __init__(self, main_indicator='acc', **kwargs):
+        self.main_indicator = main_indicator
+        # initialises the running counters (correct_num / all_num / norm_edit_dis)
+        self.reset()
+
+    def __call__(self, pred_label, *args, **kwargs):
+        # pred_label: (preds, labels) where each item is a (text, confidence)
+        # pair; confidences are ignored here.
+        preds, labels = pred_label
+        correct_num = 0
+        all_num = 0
+        norm_edit_dis = 0.0
+        for (pred, pred_conf), (target, _) in zip(preds, labels):
+            # spaces are stripped before comparison
+            pred = pred.replace(" ", "")
+            target = target.replace(" ", "")
+            # Levenshtein distance normalised by the longer string; the
+            # trailing 1 avoids dividing by zero when both strings are empty
+            norm_edit_dis += Levenshtein.distance(pred, target) / max(
+                len(pred), len(target), 1)
+
+            # print("pred", pred)
+            # print("target", target)
+            if pred == target:
+                correct_num += 1
+            all_num += 1
+        self.correct_num += correct_num
+        self.all_num += all_num
+        self.norm_edit_dis += norm_edit_dis
+        # NOTE(review): raises ZeroDivisionError when the batch is empty
+        return {
+            'acc': correct_num / all_num,
+            'norm_edit_dis': 1 - norm_edit_dis / all_num
+        }
+
+    def get_metric(self):
+        """
+        return metrics {
+                 'acc': 0,
+                 'norm_edit_dis': 0,
+            }
+        """
+        # NOTE(review): raises ZeroDivisionError if called before any batch
+        acc = 1.0 * self.correct_num / self.all_num
+        norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
+        self.reset()
+        return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
+
+    def reset(self):
+        self.correct_num = 0
+        self.all_num = 0
+        self.norm_edit_dis = 0

+ 25 - 0
ocr/ppocr/modeling/architectures/__init__.py

@@ -0,0 +1,25 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+__all__ = ['build_model']
+
+
+def build_model(config):
+    """Build an OCR model (BaseModel) from a config dict."""
+    # local import defers loading paddle until a model is actually built
+    from .base_model import BaseModel
+    
+    # deepcopy so BaseModel's in-place config mutations don't leak to the caller
+    config = copy.deepcopy(config)
+    module_class = BaseModel(config)
+    return module_class

+ 85 - 0
ocr/ppocr/modeling/architectures/base_model.py

@@ -0,0 +1,85 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from ppocr.modeling.transforms import build_transform
+from ppocr.modeling.backbones import build_backbone
+from ppocr.modeling.necks import build_neck
+from ppocr.modeling.heads import build_head
+
+__all__ = ['BaseModel']
+
+
+class BaseModel(nn.Layer):
+    def __init__(self, config):
+        """
+        the module for OCR.
+        args:
+            config (dict): the super parameters for module.
+        """
+        super(BaseModel, self).__init__()
+
+        # number of input channels (defaults to 3 for RGB images)
+        in_channels = config.get('in_channels', 3)
+        model_type = config['model_type']
+        # build transform,
+        # for rec, transform can be TPS, None
+        # for det and cls, transform should be None,
+        # if you make model differently, you can use transform in det and cls
+        if 'Transform' not in config or config['Transform'] is None:
+            self.use_transform = False
+        else:
+            self.use_transform = True
+            config['Transform']['in_channels'] = in_channels
+            self.transform = build_transform(config['Transform'])
+            in_channels = self.transform.out_channels
+
+        # build backbone, backbone is needed for det, rec and cls
+        # read the backbone config and construct the corresponding backbone class
+        config["Backbone"]['in_channels'] = in_channels
+        self.backbone = build_backbone(config["Backbone"], model_type)
+        in_channels = self.backbone.out_channels
+
+        # build neck
+        # for rec, neck can be cnn, rnn or reshape(None)
+        # for det, neck can be FPN, BIFPN and so on.
+        # for cls, neck should be none
+        if 'Neck' not in config or config['Neck'] is None:
+            self.use_neck = False
+        else:
+            self.use_neck = True
+            config['Neck']['in_channels'] = in_channels
+            self.neck = build_neck(config['Neck'])
+            in_channels = self.neck.out_channels
+
+        # # build head, head is needed for det, rec and cls
+        config["Head"]['in_channels'] = in_channels
+        self.head = build_head(config["Head"])
+
+    @paddle.jit.to_static
+    def forward(self, x, data=None):
+        # pipeline: optional transform -> backbone -> optional neck -> head;
+        # `data` (extra inputs, e.g. labels) is forwarded to the head when given
+        if self.use_transform:
+            x = self.transform(x)
+        x = self.backbone(x)
+        if self.use_neck:
+            x = self.neck(x)
+        if data is None:
+            x = self.head(x)
+        else:
+            x = self.head(x, data)
+        return x

+ 37 - 0
ocr/ppocr/modeling/backbones/__init__.py

@@ -0,0 +1,37 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ['build_backbone']
+
+
+def build_backbone(config, model_type):
+    """Build a backbone network selected by config['name'] for the given task.
+
+    model_type: 'det' (detection) or 'rec'/'cls' (recognition/classification);
+    each task type has its own whitelist of backbone classes.
+    """
+    # imports are local so only the implementations for this task are loaded
+    if model_type == 'det':
+        from .det_mobilenet_v3 import MobileNetV3
+        from .det_resnet_vd import ResNet
+        from .det_resnet_vd_sast import ResNet_SAST
+        support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
+    elif model_type == 'rec' or model_type == 'cls':
+        from .rec_mobilenet_v3 import MobileNetV3
+        from .rec_resnet_vd import ResNet
+        from .rec_resnet_fpn import ResNetFPN
+        support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
+    else:
+        raise NotImplementedError
+
+    module_name = config.pop('name')
+    assert module_name in support_dict, Exception(
+        'when model typs is {}, backbone only support {}'.format(model_type,
+                                                                 support_dict))
+    # NOTE(review): eval() on a config-supplied name; acceptable only because
+    # the assert above restricts it to the whitelist in support_dict
+    module_class = eval(module_name)(**config)
+    return module_class

+ 287 - 0
ocr/ppocr/modeling/backbones/det_mobilenet_v3.py

@@ -0,0 +1,287 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+__all__ = ['MobileNetV3']
+
+
+def make_divisible(v, divisor=8, min_value=None):
+    """Round channel count v to the nearest multiple of divisor.
+
+    Never returns less than min_value (defaults to divisor), and bumps the
+    result up one divisor step if rounding dropped it below 90% of v.
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class MobileNetV3(nn.Layer):
+    def __init__(self,
+                 in_channels=3,
+                 model_name='large',
+                 scale=0.5,
+                 disable_se=False,
+                 **kwargs):
+        """
+        the MobilenetV3 backbone network for detection module.
+        Args:
+            params(dict): the super parameters for build network
+        """
+        super(MobileNetV3, self).__init__()
+
+        # disable_se: drop all squeeze-and-excite blocks from the config
+        self.disable_se = disable_se
+
+        if model_name == "large":
+            cfg = [
+                # k, exp, c,  se,     nl,  s,
+                [3, 16, 16, False, 'relu', 1],
+                [3, 64, 24, False, 'relu', 2],
+                [3, 72, 24, False, 'relu', 1],
+                [5, 72, 40, True, 'relu', 2],
+                [5, 120, 40, True, 'relu', 1],
+                [5, 120, 40, True, 'relu', 1],
+                [3, 240, 80, False, 'hardswish', 2],
+                [3, 200, 80, False, 'hardswish', 1],
+                [3, 184, 80, False, 'hardswish', 1],
+                [3, 184, 80, False, 'hardswish', 1],
+                [3, 480, 112, True, 'hardswish', 1],
+                [3, 672, 112, True, 'hardswish', 1],
+                [5, 672, 160, True, 'hardswish', 2],
+                [5, 960, 160, True, 'hardswish', 1],
+                [5, 960, 160, True, 'hardswish', 1],
+            ]
+            cls_ch_squeeze = 960
+        elif model_name == "small":
+            cfg = [
+                # k, exp, c,  se,     nl,  s,
+                [3, 16, 16, True, 'relu', 2],
+                [3, 72, 24, False, 'relu', 2],
+                [3, 88, 24, False, 'relu', 1],
+                [5, 96, 40, True, 'hardswish', 2],
+                [5, 240, 40, True, 'hardswish', 1],
+                [5, 240, 40, True, 'hardswish', 1],
+                [5, 120, 48, True, 'hardswish', 1],
+                [5, 144, 48, True, 'hardswish', 1],
+                [5, 288, 96, True, 'hardswish', 2],
+                [5, 576, 96, True, 'hardswish', 1],
+                [5, 576, 96, True, 'hardswish', 1],
+            ]
+            cls_ch_squeeze = 576
+        else:
+            raise NotImplementedError("mode[" + model_name +
+                                      "_model] is not implemented!")
+
+        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
+        assert scale in supported_scale, \
+            "supported scale are {} but input scale is {}".format(supported_scale, scale)
+        inplanes = 16
+        # conv1
+        self.conv = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=make_divisible(inplanes * scale),
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            groups=1,
+            if_act=True,
+            act='hardswish',
+            name='conv1')
+
+        # group the residual units into stages, cutting a new stage at each
+        # stride-2 unit (past start_idx) so forward() can emit multi-scale
+        # feature maps for the detection FPN
+        self.stages = []
+        self.out_channels = []
+        block_list = []
+        i = 0
+        inplanes = make_divisible(inplanes * scale)
+        for (k, exp, c, se, nl, s) in cfg:
+            se = se and not self.disable_se
+            start_idx = 2 if model_name == 'large' else 0
+            if s == 2 and i > start_idx:
+                self.out_channels.append(inplanes)
+                self.stages.append(nn.Sequential(*block_list))
+                block_list = []
+            block_list.append(
+                ResidualUnit(
+                    in_channels=inplanes,
+                    mid_channels=make_divisible(scale * exp),
+                    out_channels=make_divisible(scale * c),
+                    kernel_size=k,
+                    stride=s,
+                    use_se=se,
+                    act=nl,
+                    name="conv" + str(i + 2)))
+            inplanes = make_divisible(scale * c)
+            i += 1
+        # final 1x1 conv closes the last stage
+        block_list.append(
+            ConvBNLayer(
+                in_channels=inplanes,
+                out_channels=make_divisible(scale * cls_ch_squeeze),
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                groups=1,
+                if_act=True,
+                act='hardswish',
+                name='conv_last'))
+        self.stages.append(nn.Sequential(*block_list))
+        self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
+        # register each stage so its parameters are tracked by paddle
+        for i, stage in enumerate(self.stages):
+            self.add_sublayer(sublayer=stage, name="stage{}".format(i))
+
+    def forward(self, x):
+        # returns the feature map after every stage (multi-scale outputs)
+        x = self.conv(x)
+        out_list = []
+        for stage in self.stages:
+            x = stage(x)
+            out_list.append(x)
+        return out_list
+
+
+class ConvBNLayer(nn.Layer):
+    # Conv2D + BatchNorm + optional activation ('relu' or 'hardswish').
+    # All parameters are created with trainable=False, i.e. this backbone
+    # is built with frozen (inference-only) weights.
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 groups=1,
+                 if_act=True,
+                 act=None,
+                 name=None):
+        super(ConvBNLayer, self).__init__()
+        self.if_act = if_act
+        self.act = act
+        self.conv = nn.Conv2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=groups,
+            weight_attr=ParamAttr(name=name + '_weights', trainable=False),
+            bias_attr=False)
+
+        self.bn = nn.BatchNorm(
+            num_channels=out_channels,
+            act=None,
+            param_attr=ParamAttr(name=name + "_bn_scale", trainable=False),
+            bias_attr=ParamAttr(name=name + "_bn_offset", trainable=False),
+            moving_mean_name=name + "_bn_mean",
+            moving_variance_name=name + "_bn_variance")
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        if self.if_act:
+            if self.act == "relu":
+                x = F.relu(x)
+            elif self.act == "hardswish":
+                x = F.hardswish(x)
+            else:
+                # NOTE(review): terminates the whole process on an unknown
+                # activation instead of raising — fragile error handling
+                print("The activation function({}) is selected incorrectly.".
+                      format(self.act))
+                exit()
+        return x
+
+
+class ResidualUnit(nn.Layer):
+    # MobileNetV3 inverted-residual unit: 1x1 expand -> depthwise kxk ->
+    # optional SE -> 1x1 linear projection, with identity shortcut when
+    # stride is 1 and channel counts match.
+    def __init__(self,
+                 in_channels,
+                 mid_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 use_se,
+                 act=None,
+                 name=''):
+        super(ResidualUnit, self).__init__()
+        self.if_shortcut = stride == 1 and in_channels == out_channels
+        self.if_se = use_se
+
+        self.expand_conv = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=mid_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            if_act=True,
+            act=act,
+            name=name + "_expand")
+        # groups == mid_channels makes this a depthwise convolution
+        self.bottleneck_conv = ConvBNLayer(
+            in_channels=mid_channels,
+            out_channels=mid_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=int((kernel_size - 1) // 2),
+            groups=mid_channels,
+            if_act=True,
+            act=act,
+            name=name + "_depthwise")
+        if self.if_se:
+            self.mid_se = SEModule(mid_channels, name=name + "_se")
+        # linear (no activation) projection back to out_channels
+        self.linear_conv = ConvBNLayer(
+            in_channels=mid_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            if_act=False,
+            act=None,
+            name=name + "_linear")
+
+    def forward(self, inputs):
+        x = self.expand_conv(inputs)
+        x = self.bottleneck_conv(x)
+        if self.if_se:
+            x = self.mid_se(x)
+        x = self.linear_conv(x)
+        if self.if_shortcut:
+            x = paddle.add(inputs, x)
+        return x
+
+
+class SEModule(nn.Layer):
+    # Squeeze-and-excite: global average pool -> 1x1 reduce (relu) ->
+    # 1x1 expand (hardsigmoid) -> channel-wise rescale of the input.
+    def __init__(self, in_channels, reduction=4, name=""):
+        super(SEModule, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2D(1)
+        self.conv1 = nn.Conv2D(
+            in_channels=in_channels,
+            out_channels=in_channels // reduction,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            weight_attr=ParamAttr(name=name + "_1_weights", trainable=False),
+            bias_attr=ParamAttr(name=name + "_1_offset", trainable=False))
+        self.conv2 = nn.Conv2D(
+            in_channels=in_channels // reduction,
+            out_channels=in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            weight_attr=ParamAttr(name + "_2_weights", trainable=False),
+            bias_attr=ParamAttr(name=name + "_2_offset", trainable=False))
+
+    def forward(self, inputs):
+        outputs = self.avg_pool(inputs)
+        outputs = self.conv1(outputs)
+        outputs = F.relu(outputs)
+        outputs = self.conv2(outputs)
+        # hardsigmoid with MobileNetV3's slope/offset (0.2, 0.5) gives the
+        # per-channel gating weights in [0, 1]
+        outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
+        return inputs * outputs

+ 280 - 0
ocr/ppocr/modeling/backbones/det_resnet_vd.py

@@ -0,0 +1,280 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+__all__ = ["ResNet"]
+
+
+class ConvBNLayer(nn.Layer):
+    # Conv2D + BatchNorm for ResNet-vd; when is_vd_mode is set, a 2x2
+    # average pool is applied first (the "vd" downsampling trick used on
+    # shortcut branches instead of a strided 1x1 conv).
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            groups=1,
+            is_vd_mode=False,
+            act=None,
+            name=None, ):
+        super(ConvBNLayer, self).__init__()
+
+        self.is_vd_mode = is_vd_mode
+        self._pool2d_avg = nn.AvgPool2D(
+            kernel_size=2, stride=2, padding=0, ceil_mode=True)
+        self._conv = nn.Conv2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=(kernel_size - 1) // 2,
+            groups=groups,
+            weight_attr=ParamAttr(name=name + "_weights"),
+            bias_attr=False)
+        # batch-norm parameter names follow the pretrained-weight convention:
+        # "bn_conv1" for the stem, otherwise "bn" + name without "res" prefix
+        if name == "conv1":
+            bn_name = "bn_" + name
+        else:
+            bn_name = "bn" + name[3:]
+        self._batch_norm = nn.BatchNorm(
+            out_channels,
+            act=act,
+            param_attr=ParamAttr(name=bn_name + '_scale'),
+            bias_attr=ParamAttr(bn_name + '_offset'),
+            moving_mean_name=bn_name + '_mean',
+            moving_variance_name=bn_name + '_variance')
+
+    def forward(self, inputs):
+        if self.is_vd_mode:
+            inputs = self._pool2d_avg(inputs)
+        y = self._conv(inputs)
+        y = self._batch_norm(y)
+        return y
+
+
+class BottleneckBlock(nn.Layer):
+    # ResNet bottleneck: 1x1 reduce -> 3x3 (strided) -> 1x1 expand (4x),
+    # plus a shortcut branch (identity, or a projection conv when channel
+    # counts / resolution change).
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 shortcut=True,
+                 if_first=False,
+                 name=None):
+        super(BottleneckBlock, self).__init__()
+
+        self.conv0 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            act='relu',
+            name=name + "_branch2a")
+        self.conv1 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=stride,
+            act='relu',
+            name=name + "_branch2b")
+        self.conv2 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels * 4,
+            kernel_size=1,
+            act=None,
+            name=name + "_branch2c")
+
+        if not shortcut:
+            # projection shortcut; vd-mode avg-pool downsampling is used
+            # except in the very first block of the network
+            self.short = ConvBNLayer(
+                in_channels=in_channels,
+                out_channels=out_channels * 4,
+                kernel_size=1,
+                stride=1,
+                is_vd_mode=False if if_first else True,
+                name=name + "_branch1")
+
+        self.shortcut = shortcut
+
+    def forward(self, inputs):
+        y = self.conv0(inputs)
+        conv1 = self.conv1(y)
+        conv2 = self.conv2(conv1)
+
+        if self.shortcut:
+            short = inputs
+        else:
+            short = self.short(inputs)
+        y = paddle.add(x=short, y=conv2)
+        y = F.relu(y)
+        return y
+
+
+class BasicBlock(nn.Layer):
+    # ResNet basic block (for 18/34-layer variants): two 3x3 convs plus a
+    # shortcut branch (identity, or a projection conv when shapes change).
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 shortcut=True,
+                 if_first=False,
+                 name=None):
+        super(BasicBlock, self).__init__()
+        self.stride = stride
+        self.conv0 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=stride,
+            act='relu',
+            name=name + "_branch2a")
+        self.conv1 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            act=None,
+            name=name + "_branch2b")
+
+        if not shortcut:
+            # projection shortcut; vd-mode avg-pool downsampling is used
+            # except in the very first block of the network
+            self.short = ConvBNLayer(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+                stride=1,
+                is_vd_mode=False if if_first else True,
+                name=name + "_branch1")
+
+        self.shortcut = shortcut
+
+    def forward(self, inputs):
+        y = self.conv0(inputs)
+        conv1 = self.conv1(y)
+
+        if self.shortcut:
+            short = inputs
+        else:
+            short = self.short(inputs)
+        y = paddle.add(x=short, y=conv1)
+        y = F.relu(y)
+        return y
+
+
+class ResNet(nn.Layer):
+    """ResNet-vd backbone for text detection.
+
+    forward() returns the output feature map of every stage (a list), so an
+    FPN-style neck can consume multiple scales.
+    """
+
+    def __init__(self, in_channels=3, layers=50, **kwargs):
+        super(ResNet, self).__init__()
+
+        self.layers = layers
+        supported_layers = [18, 34, 50, 101, 152, 200]
+        assert layers in supported_layers, \
+            "supported layers are {} but input layer is {}".format(
+                supported_layers, layers)
+
+        # Number of residual blocks in each of the four stages.
+        if layers == 18:
+            depth = [2, 2, 2, 2]
+        elif layers == 34 or layers == 50:
+            depth = [3, 4, 6, 3]
+        elif layers == 101:
+            depth = [3, 4, 23, 3]
+        elif layers == 152:
+            depth = [3, 8, 36, 3]
+        elif layers == 200:
+            depth = [3, 12, 48, 3]
+        # Per-stage input channels differ between bottleneck (>=50) and basic nets.
+        num_channels = [64, 256, 512,
+                        1024] if layers >= 50 else [64, 64, 128, 256]
+        num_filters = [64, 128, 256, 512]
+
+        # ResNet-vd stem: three 3x3 convs instead of a single 7x7 conv.
+        self.conv1_1 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=32,
+            kernel_size=3,
+            stride=2,
+            act='relu',
+            name="conv1_1")
+        self.conv1_2 = ConvBNLayer(
+            in_channels=32,
+            out_channels=32,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv1_2")
+        self.conv1_3 = ConvBNLayer(
+            in_channels=32,
+            out_channels=64,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv1_3")
+        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+        self.stages = []
+        self.out_channels = []
+        if layers >= 50:
+            # Deep variants use BottleneckBlock (output channels = filters * 4).
+            for block in range(len(depth)):
+                block_list = []
+                shortcut = False
+                for i in range(depth[block]):
+                    # Pretrained-weight naming convention: ResNet-101/152 name
+                    # stage-3 blocks "a", "b1", "b2", ... instead of "a".."z".
+                    if layers in [101, 152] and block == 2:
+                        if i == 0:
+                            conv_name = "res" + str(block + 2) + "a"
+                        else:
+                            conv_name = "res" + str(block + 2) + "b" + str(i)
+                    else:
+                        conv_name = "res" + str(block + 2) + chr(97 + i)
+                    bottleneck_block = self.add_sublayer(
+                        'bb_%d_%d' % (block, i),
+                        BottleneckBlock(
+                            in_channels=num_channels[block]
+                            if i == 0 else num_filters[block] * 4,
+                            out_channels=num_filters[block],
+                            # First block of every stage except stage 0 downsamples.
+                            stride=2 if i == 0 and block != 0 else 1,
+                            shortcut=shortcut,
+                            # Only the first block of the first stage skips the
+                            # vd avg-pool in its shortcut.
+                            if_first=block == i == 0,
+                            name=conv_name))
+                    shortcut = True
+                    block_list.append(bottleneck_block)
+                self.out_channels.append(num_filters[block] * 4)
+                self.stages.append(nn.Sequential(*block_list))
+        else:
+            # Shallow variants (18/34) use BasicBlock.
+            for block in range(len(depth)):
+                block_list = []
+                shortcut = False
+                for i in range(depth[block]):
+                    conv_name = "res" + str(block + 2) + chr(97 + i)
+                    basic_block = self.add_sublayer(
+                        'bb_%d_%d' % (block, i),
+                        BasicBlock(
+                            in_channels=num_channels[block]
+                            if i == 0 else num_filters[block],
+                            out_channels=num_filters[block],
+                            stride=2 if i == 0 and block != 0 else 1,
+                            shortcut=shortcut,
+                            if_first=block == i == 0,
+                            name=conv_name))
+                    shortcut = True
+                    block_list.append(basic_block)
+                self.out_channels.append(num_filters[block])
+                self.stages.append(nn.Sequential(*block_list))
+
+    def forward(self, inputs):
+        """Run the stem and all stages; return a list with each stage's output."""
+        y = self.conv1_1(inputs)
+        y = self.conv1_2(y)
+        y = self.conv1_3(y)
+        y = self.pool2d_max(y)
+        out = []
+        for block in self.stages:
+            y = block(y)
+            out.append(y)
+        return out

+ 285 - 0
ocr/ppocr/modeling/backbones/det_resnet_vd_sast.py

@@ -0,0 +1,285 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+__all__ = ["ResNet_SAST"]
+
+
+class ConvBNLayer(nn.Layer):
+    """Conv2D + BatchNorm (+ optional activation) helper.
+
+    When ``is_vd_mode`` is True a 2x2 average pool is applied before the conv —
+    the ResNet-vd downsampling trick used in projection shortcuts.
+    """
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            groups=1,
+            is_vd_mode=False,
+            act=None,
+            name=None, ):
+        super(ConvBNLayer, self).__init__()
+
+        self.is_vd_mode = is_vd_mode
+        self._pool2d_avg = nn.AvgPool2D(
+            kernel_size=2, stride=2, padding=0, ceil_mode=True)
+        self._conv = nn.Conv2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            # "same"-style padding for odd kernel sizes.
+            padding=(kernel_size - 1) // 2,
+            groups=groups,
+            weight_attr=ParamAttr(name=name + "_weights"),
+            bias_attr=False)
+        # BN parameter names follow the pretrained-checkpoint convention:
+        # "conv1" -> "bn_conv1", "resXY_..." -> "bnXY_...".
+        if name == "conv1":
+            bn_name = "bn_" + name
+        else:
+            bn_name = "bn" + name[3:]
+        self._batch_norm = nn.BatchNorm(
+            out_channels,
+            act=act,
+            param_attr=ParamAttr(name=bn_name + '_scale'),
+            bias_attr=ParamAttr(bn_name + '_offset'),
+            moving_mean_name=bn_name + '_mean',
+            moving_variance_name=bn_name + '_variance')
+
+    def forward(self, inputs):
+        if self.is_vd_mode:
+            inputs = self._pool2d_avg(inputs)
+        y = self._conv(inputs)
+        y = self._batch_norm(y)
+        return y
+
+
+class BottleneckBlock(nn.Layer):
+    """Bottleneck residual block (1x1 -> 3x3 -> 1x1; output channels = out_channels * 4)."""
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 shortcut=True,
+                 if_first=False,
+                 name=None):
+        super(BottleneckBlock, self).__init__()
+
+        self.conv0 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            act='relu',
+            name=name + "_branch2a")
+        self.conv1 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=stride,
+            act='relu',
+            name=name + "_branch2b")
+        self.conv2 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels * 4,
+            kernel_size=1,
+            act=None,
+            name=name + "_branch2c")
+
+        if not shortcut:
+            # Projection shortcut; avg-pool (vd mode) except in the first block.
+            self.short = ConvBNLayer(
+                in_channels=in_channels,
+                out_channels=out_channels * 4,
+                kernel_size=1,
+                stride=1,
+                is_vd_mode=False if if_first else True,
+                name=name + "_branch1")
+
+        self.shortcut = shortcut
+
+    def forward(self, inputs):
+        y = self.conv0(inputs)
+        conv1 = self.conv1(y)
+        conv2 = self.conv2(conv1)
+
+        if self.shortcut:
+            short = inputs
+        else:
+            short = self.short(inputs)
+        # Residual add followed by ReLU.
+        y = paddle.add(x=short, y=conv2)
+        y = F.relu(y)
+        return y
+
+
+class BasicBlock(nn.Layer):
+    """Two-layer (3x3 -> 3x3) residual block for the shallow (18/34) variants."""
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 shortcut=True,
+                 if_first=False,
+                 name=None):
+        super(BasicBlock, self).__init__()
+        self.stride = stride
+        self.conv0 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=stride,
+            act='relu',
+            name=name + "_branch2a")
+        self.conv1 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            act=None,
+            name=name + "_branch2b")
+
+        if not shortcut:
+            # Projection shortcut; avg-pool (vd mode) except in the first block.
+            self.short = ConvBNLayer(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+                stride=1,
+                is_vd_mode=False if if_first else True,
+                name=name + "_branch1")
+
+        self.shortcut = shortcut
+
+    def forward(self, inputs):
+        y = self.conv0(inputs)
+        conv1 = self.conv1(y)
+
+        if self.shortcut:
+            short = inputs
+        else:
+            short = self.short(inputs)
+        y = paddle.add(x=short, y=conv1)
+        y = F.relu(y)
+        return y
+
+
+class ResNet_SAST(nn.Layer):
+    """ResNet-vd backbone variant for the SAST detector.
+
+    Differences from the plain detection ResNet:
+    - layers==34/50 get a fifth stage (depth [3, 4, 6, 3, 3]);
+    - forward() also returns the raw input and the stem output, hence
+      out_channels is seeded with [3, 64].
+    """
+
+    def __init__(self, in_channels=3, layers=50, **kwargs):
+        super(ResNet_SAST, self).__init__()
+
+        self.layers = layers
+        supported_layers = [18, 34, 50, 101, 152, 200]
+        assert layers in supported_layers, \
+            "supported layers are {} but input layer is {}".format(
+                supported_layers, layers)
+
+        if layers == 18:
+            depth = [2, 2, 2, 2]
+        elif layers == 34 or layers == 50:
+            # depth = [3, 4, 6, 3]
+            # SAST adds a fifth stage compared with the standard ResNet.
+            depth = [3, 4, 6, 3, 3]
+        elif layers == 101:
+            depth = [3, 4, 23, 3]
+        elif layers == 152:
+            depth = [3, 8, 36, 3]
+        elif layers == 200:
+            depth = [3, 12, 48, 3]
+        # num_channels = [64, 256, 512,
+        #                 1024] if layers >= 50 else [64, 64, 128, 256]
+        # num_filters = [64, 128, 256, 512]
+        # Extended channel tables for the extra fifth stage.
+        num_channels = [64, 256, 512,
+                        1024, 2048] if layers >= 50 else [64, 64, 128, 256]
+        num_filters = [64, 128, 256, 512, 512]
+
+        # ResNet-vd stem: three 3x3 convs instead of a single 7x7 conv.
+        self.conv1_1 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=32,
+            kernel_size=3,
+            stride=2,
+            act='relu',
+            name="conv1_1")
+        self.conv1_2 = ConvBNLayer(
+            in_channels=32,
+            out_channels=32,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv1_2")
+        self.conv1_3 = ConvBNLayer(
+            in_channels=32,
+            out_channels=64,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv1_3")
+        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+        self.stages = []
+        # Seeded with the input image channels (3) and the stem output (64),
+        # matching the two extra feature maps returned by forward().
+        self.out_channels = [3, 64]
+        if layers >= 50:
+            for block in range(len(depth)):
+                block_list = []
+                shortcut = False
+                for i in range(depth[block]):
+                    # Pretrained-weight naming: 101/152 use "a", "b1", "b2", ...
+                    # for stage 3; everything else uses "a".."z".
+                    if layers in [101, 152] and block == 2:
+                        if i == 0:
+                            conv_name = "res" + str(block + 2) + "a"
+                        else:
+                            conv_name = "res" + str(block + 2) + "b" + str(i)
+                    else:
+                        conv_name = "res" + str(block + 2) + chr(97 + i)
+                    bottleneck_block = self.add_sublayer(
+                        'bb_%d_%d' % (block, i),
+                        BottleneckBlock(
+                            in_channels=num_channels[block]
+                            if i == 0 else num_filters[block] * 4,
+                            out_channels=num_filters[block],
+                            stride=2 if i == 0 and block != 0 else 1,
+                            shortcut=shortcut,
+                            if_first=block == i == 0,
+                            name=conv_name))
+                    shortcut = True
+                    block_list.append(bottleneck_block)
+                self.out_channels.append(num_filters[block] * 4)
+                self.stages.append(nn.Sequential(*block_list))
+        else:
+            for block in range(len(depth)):
+                block_list = []
+                shortcut = False
+                for i in range(depth[block]):
+                    conv_name = "res" + str(block + 2) + chr(97 + i)
+                    basic_block = self.add_sublayer(
+                        'bb_%d_%d' % (block, i),
+                        BasicBlock(
+                            in_channels=num_channels[block]
+                            if i == 0 else num_filters[block],
+                            out_channels=num_filters[block],
+                            stride=2 if i == 0 and block != 0 else 1,
+                            shortcut=shortcut,
+                            if_first=block == i == 0,
+                            name=conv_name))
+                    shortcut = True
+                    block_list.append(basic_block)
+                self.out_channels.append(num_filters[block])
+                self.stages.append(nn.Sequential(*block_list))
+
+    def forward(self, inputs):
+        """Return [input, stem_output, stage1_out, stage2_out, ...]."""
+        out = [inputs]
+        y = self.conv1_1(inputs)
+        y = self.conv1_2(y)
+        y = self.conv1_3(y)
+        out.append(y)
+        y = self.pool2d_max(y)
+        for block in self.stages:
+            y = block(y)
+            out.append(y)
+        return out

+ 146 - 0
ocr/ppocr/modeling/backbones/rec_mobilenet_v3.py

@@ -0,0 +1,146 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from paddle import nn
+
+from ppocr.modeling.backbones.det_mobilenet_v3 import ResidualUnit, ConvBNLayer, make_divisible
+
+__all__ = ['MobileNetV3']
+
+
+class MobileNetV3(nn.Layer):
+    """MobileNetV3 backbone for text recognition.
+
+    Compared with the detection variant, several stages stride only in the
+    height dimension (stride tuples like (2, 1)) so the horizontal resolution
+    needed for sequence recognition is preserved.
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 model_name='small',
+                 scale=0.5,
+                 large_stride=None,
+                 small_stride=None,
+                 **kwargs):
+        super(MobileNetV3, self).__init__()
+        # Default vertical strides for the four down-sampling stages.
+        if small_stride is None:
+            small_stride = [2, 2, 2, 2]
+        if large_stride is None:
+            large_stride = [1, 2, 2, 2]
+
+        assert isinstance(large_stride, list), "large_stride type must " \
+                                               "be list but got {}".format(type(large_stride))
+        assert isinstance(small_stride, list), "small_stride type must " \
+                                               "be list but got {}".format(type(small_stride))
+        assert len(large_stride) == 4, "large_stride length must be " \
+                                       "4 but got {}".format(len(large_stride))
+        assert len(small_stride) == 4, "small_stride length must be " \
+                                       "4 but got {}".format(len(small_stride))
+        logging.info("MobileNetV3, model_name"+model_name)
+
+        if model_name == "large":
+            cfg = [
+                # k, exp, c,  se,     nl,  s,
+                [3, 16, 16, False, 'relu', large_stride[0]],
+                [3, 64, 24, False, 'relu', (large_stride[1], 1)],
+                [3, 72, 24, False, 'relu', 1],
+                [5, 72, 40, True, 'relu', (large_stride[2], 1)],
+                [5, 120, 40, True, 'relu', 1],
+                [5, 120, 40, True, 'relu', 1],
+                [3, 240, 80, False, 'hardswish', 1],
+                [3, 200, 80, False, 'hardswish', 1],
+                [3, 184, 80, False, 'hardswish', 1],
+                [3, 184, 80, False, 'hardswish', 1],
+                [3, 480, 112, True, 'hardswish', 1],
+                [3, 672, 112, True, 'hardswish', 1],
+                [5, 672, 160, True, 'hardswish', (large_stride[3], 1)],
+                [5, 960, 160, True, 'hardswish', 1],
+                [5, 960, 160, True, 'hardswish', 1],
+            ]
+            cls_ch_squeeze = 960
+        elif model_name == "small":
+            cfg = [
+                # k, exp, c,  se,     nl,  s,
+                [3, 16, 16, True, 'relu', (small_stride[0], 1)],
+                [3, 72, 24, False, 'relu', (small_stride[1], 1)],
+                [3, 88, 24, False, 'relu', 1],
+                [5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
+                [5, 240, 40, True, 'hardswish', 1],
+                [5, 240, 40, True, 'hardswish', 1],
+                [5, 120, 48, True, 'hardswish', 1],
+                [5, 144, 48, True, 'hardswish', 1],
+                [5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
+                [5, 576, 96, True, 'hardswish', 1],
+                [5, 576, 96, True, 'hardswish', 1],
+            ]
+            cls_ch_squeeze = 576
+        else:
+            raise NotImplementedError("mode[" + model_name +
+                                      "_model] is not implemented!")
+
+        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
+        assert scale in supported_scale, \
+            "supported scales are {} but input scale is {}".format(supported_scale, scale)
+
+        inplanes = 16
+        # conv1
+        self.conv1 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=make_divisible(inplanes * scale),
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            groups=1,
+            if_act=True,
+            act='hardswish',
+            name='conv1'
+        )
+
+        # blocks
+        # Residual (inverted-residual) units built from the cfg table above.
+        i = 0
+        block_list = []
+        inplanes = make_divisible(inplanes * scale)
+        for (k, exp, c, se, nl, s) in cfg:
+            block_list.append(
+                ResidualUnit(
+                    in_channels=inplanes,
+                    mid_channels=make_divisible(scale * exp),
+                    out_channels=make_divisible(scale * c),
+                    kernel_size=k,
+                    stride=s,
+                    use_se=se,
+                    act=nl,
+                    name='conv' + str(i + 2)))
+            inplanes = make_divisible(scale * c)
+            i += 1
+        self.blocks = nn.Sequential(*block_list)
+
+        # conv2
+        self.conv2 = ConvBNLayer(
+            in_channels=inplanes,
+            out_channels=make_divisible(scale * cls_ch_squeeze),
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            if_act=True,
+            act='hardswish',
+            name='conv_last')
+
+        # pool
+        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+        self.out_channels = make_divisible(scale * cls_ch_squeeze)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.blocks(x)
+        x = self.conv2(x)
+        x = self.pool(x)
+        return x

+ 307 - 0
ocr/ppocr/modeling/backbones/rec_resnet_fpn.py

@@ -0,0 +1,307 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import paddle.fluid as fluid
+import paddle
+import numpy as np
+
+__all__ = ["ResNetFPN"]
+
+
+class ResNetFPN(nn.Layer):
+    """ResNet backbone with a small top-down FPN fusion head (for SRN-style recognition)."""
+
+    def __init__(self, in_channels=1, layers=50, **kwargs):
+        super(ResNetFPN, self).__init__()
+        # Per-depth block counts and block class.
+        supported_layers = {
+            18: {
+                'depth': [2, 2, 2, 2],
+                'block_class': BasicBlock
+            },
+            34: {
+                'depth': [3, 4, 6, 3],
+                'block_class': BasicBlock
+            },
+            50: {
+                'depth': [3, 4, 6, 3],
+                'block_class': BottleneckBlock
+            },
+            101: {
+                'depth': [3, 4, 23, 3],
+                'block_class': BottleneckBlock
+            },
+            152: {
+                'depth': [3, 8, 36, 3],
+                'block_class': BottleneckBlock
+            }
+        }
+        # Later stages keep stride (1, 1) to preserve spatial resolution.
+        stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
+        num_filters = [64, 128, 256, 512]
+        self.depth = supported_layers[layers]['depth']
+        self.F = []
+        self.conv = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=64,
+            kernel_size=7,
+            stride=2,
+            act="relu",
+            name="conv1")
+        self.block_list = []
+        in_ch = 64
+        if layers >= 50:
+            for block in range(len(self.depth)):
+                for i in range(self.depth[block]):
+                    # Pretrained-weight naming: 101/152 use "a", "b1", ... for stage 3.
+                    if layers in [101, 152] and block == 2:
+                        if i == 0:
+                            conv_name = "res" + str(block + 2) + "a"
+                        else:
+                            conv_name = "res" + str(block + 2) + "b" + str(i)
+                    else:
+                        conv_name = "res" + str(block + 2) + chr(97 + i)
+                    block_list = self.add_sublayer(
+                        "bottleneckBlock_{}_{}".format(block, i),
+                        BottleneckBlock(
+                            in_channels=in_ch,
+                            out_channels=num_filters[block],
+                            stride=stride_list[block] if i == 0 else 1,
+                            name=conv_name))
+                    in_ch = num_filters[block] * 4
+                    self.block_list.append(block_list)
+                self.F.append(block_list)
+        else:
+            for block in range(len(self.depth)):
+                for i in range(self.depth[block]):
+                    conv_name = "res" + str(block + 2) + chr(97 + i)
+                    # NOTE(review): this local `stride` is computed but unused —
+                    # stride_list[block] is what is actually passed below.
+                    if i == 0 and block != 0:
+                        stride = (2, 1)
+                    else:
+                        stride = (1, 1)
+                    basic_block = self.add_sublayer(
+                        conv_name,
+                        BasicBlock(
+                            in_channels=in_ch,
+                            out_channels=num_filters[block],
+                            stride=stride_list[block] if i == 0 else 1,
+                            is_first=block == i == 0,
+                            name=conv_name))
+                    in_ch = basic_block.out_channels
+                    self.block_list.append(basic_block)
+        out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
+        self.base_block = []
+        # NOTE(review): conv_trans/bn_block are never populated but are indexed
+        # in __call__ when feature shapes mismatch — that path would raise
+        # IndexError; presumably stride_list guarantees shapes always match.
+        self.conv_trans = []
+        self.bn_block = []
+        # Build two fusion steps: concat -> 1x1 conv -> 3x3 conv -> BN+ReLU.
+        for i in [-2, -3]:
+            in_channels = out_ch_list[i + 1] + out_ch_list[i]
+
+            self.base_block.append(
+                self.add_sublayer(
+                    "F_{}_base_block_0".format(i),
+                    nn.Conv2D(
+                        in_channels=in_channels,
+                        out_channels=out_ch_list[i],
+                        kernel_size=1,
+                        weight_attr=ParamAttr(trainable=True),
+                        bias_attr=ParamAttr(trainable=True))))
+            self.base_block.append(
+                self.add_sublayer(
+                    "F_{}_base_block_1".format(i),
+                    nn.Conv2D(
+                        in_channels=out_ch_list[i],
+                        out_channels=out_ch_list[i],
+                        kernel_size=3,
+                        padding=1,
+                        weight_attr=ParamAttr(trainable=True),
+                        bias_attr=ParamAttr(trainable=True))))
+            self.base_block.append(
+                self.add_sublayer(
+                    "F_{}_base_block_2".format(i),
+                    nn.BatchNorm(
+                        num_channels=out_ch_list[i],
+                        act="relu",
+                        param_attr=ParamAttr(trainable=True),
+                        bias_attr=ParamAttr(trainable=True))))
+        # NOTE(review): `i` here is the leaked loop variable (-3), so this layer
+        # is named "F_-3_base_block_3" — intentional-looking but fragile.
+        self.base_block.append(
+            self.add_sublayer(
+                "F_{}_base_block_3".format(i),
+                nn.Conv2D(
+                    in_channels=out_ch_list[i],
+                    out_channels=512,
+                    kernel_size=1,
+                    bias_attr=ParamAttr(trainable=True),
+                    weight_attr=ParamAttr(trainable=True))))
+        self.out_channels = 512
+
+    def __call__(self, x):
+        x = self.conv(x)
+        # Indices (cumulative depths) of the last block of each stage.
+        fpn_list = []
+        F = []
+        for i in range(len(self.depth)):
+            fpn_list.append(np.sum(self.depth[:i + 1]))
+
+        for i, block in enumerate(self.block_list):
+            x = block(x)
+            for number in fpn_list:
+                if i + 1 == number:
+                    # Collect each stage's final feature map.
+                    F.append(x)
+        base = F[-1]
+
+        j = 0
+        for i, block in enumerate(self.base_block):
+            # Every third layer (start of a fusion step) concatenates the next
+            # shallower stage feature before the conv/BN sequence runs.
+            if i % 3 == 0 and i < 6:
+                j = j + 1
+                b, c, w, h = F[-j - 1].shape
+                if [w, h] == list(base.shape[2:]):
+                    base = base
+                else:
+                    # NOTE(review): conv_trans/bn_block are empty lists (see
+                    # __init__), so this branch would fail if ever taken.
+                    base = self.conv_trans[j - 1](base)
+                    base = self.bn_block[j - 1](base)
+                base = paddle.concat([base, F[-j - 1]], axis=1)
+            base = block(base)
+        return base
+
+
+class ConvBNLayer(nn.Layer):
+    """Conv2D + BatchNorm helper for ResNetFPN.
+
+    NOTE(review): when stride == (1, 1) the conv switches to kernel_size=2 with
+    dilation=2 (effective 3x3 receptive field), yet padding is still derived
+    from the caller's kernel_size — confirm output sizes match intent.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 groups=1,
+                 act=None,
+                 name=None):
+        super(ConvBNLayer, self).__init__()
+        self.conv = nn.Conv2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=2 if stride == (1, 1) else kernel_size,
+            dilation=2 if stride == (1, 1) else 1,
+            stride=stride,
+            padding=(kernel_size - 1) // 2,
+            groups=groups,
+            weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'),
+            bias_attr=False, )
+
+        # Pretrained-checkpoint BN naming: "conv1" -> "bn_conv1", else "bnXY...".
+        if name == "conv1":
+            bn_name = "bn_" + name
+        else:
+            bn_name = "bn" + name[3:]
+        self.bn = nn.BatchNorm(
+            num_channels=out_channels,
+            act=act,
+            param_attr=ParamAttr(name=name + '.output.1.w_0'),
+            bias_attr=ParamAttr(name=name + '.output.1.b_0'),
+            moving_mean_name=bn_name + "_mean",
+            moving_variance_name=bn_name + "_variance")
+
+    def __call__(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+
+class ShortCut(nn.Layer):
+    """Residual shortcut: identity when shapes already match, else a 1x1 Conv+BN projection."""
+
+    def __init__(self, in_channels, out_channels, stride, name, is_first=False):
+        super(ShortCut, self).__init__()
+        self.use_conv = True
+
+        # A projection is needed whenever channels or stride change, or for the
+        # very first block.
+        if in_channels != out_channels or stride != 1 or is_first == True:
+            if stride == (1, 1):
+                self.conv = ConvBNLayer(
+                    in_channels, out_channels, 1, 1, name=name)
+            else:  # stride==(2,2)
+                self.conv = ConvBNLayer(
+                    in_channels, out_channels, 1, stride, name=name)
+        else:
+            self.use_conv = False
+
+    def forward(self, x):
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+
+class BottleneckBlock(nn.Layer):
+    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with a ShortCut projection."""
+
+    def __init__(self, in_channels, out_channels, stride, name):
+        super(BottleneckBlock, self).__init__()
+        self.conv0 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            act='relu',
+            name=name + "_branch2a")
+        self.conv1 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=stride,
+            act='relu',
+            name=name + "_branch2b")
+
+        self.conv2 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels * 4,
+            kernel_size=1,
+            act=None,
+            name=name + "_branch2c")
+
+        self.short = ShortCut(
+            in_channels=in_channels,
+            out_channels=out_channels * 4,
+            stride=stride,
+            is_first=False,
+            name=name + "_branch1")
+        # Bottleneck expansion factor of 4.
+        self.out_channels = out_channels * 4
+
+    def forward(self, x):
+        y = self.conv0(x)
+        y = self.conv1(y)
+        y = self.conv2(y)
+        # Residual add followed by ReLU.
+        y = y + self.short(x)
+        y = F.relu(y)
+        return y
+
+
+class BasicBlock(nn.Layer):
+    """Two-layer (3x3 -> 3x3) residual block with a ShortCut projection."""
+
+    def __init__(self, in_channels, out_channels, stride, name, is_first):
+        super(BasicBlock, self).__init__()
+        self.conv0 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            act='relu',
+            stride=stride,
+            name=name + "_branch2a")
+        self.conv1 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            act=None,
+            name=name + "_branch2b")
+        self.short = ShortCut(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            stride=stride,
+            is_first=is_first,
+            name=name + "_branch1")
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        y = self.conv0(x)
+        y = self.conv1(y)
+        # Residual add followed by ReLU.
+        y = y + self.short(x)
+        return F.relu(y)

Bu fark içinde çok fazla dosya değişikliği olduğu için bazı dosyalar gösterilmiyor