diff --git a/pdf_table_extraction_and_ocr.org b/pdf_table_extraction_and_ocr.org index 61dfd67..327e867 100644 --- a/pdf_table_extraction_and_ocr.org +++ b/pdf_table_extraction_and_ocr.org @@ -917,8 +917,25 @@ for image, tables in results: #+BEGIN_SRC python :tangle table_ocr/extract_cells/__init__.py import cv2 +import os <> + +def main(f): + results = [] + directory, filename = os.path.split(f) + table = cv2.imread(f, cv2.IMREAD_GRAYSCALE) + rows = extract_cell_images_from_table(table) + cell_img_dir = os.path.join(directory, "cells") + os.makedirs(cell_img_dir, exist_ok=True) + paths = [] + for i, row in enumerate(rows): + for j, cell in enumerate(row): + cell_filename = "{:03d}-{:03d}.png".format(i, j) + path = os.path.join(cell_img_dir, cell_filename) + cv2.imwrite(path, cell) + paths.append(path) + return paths #+END_SRC **** table_ocr/extract_cells/__main__.py @@ -936,32 +953,12 @@ Prints to stdout the lexicographically sorted list of filenames of the extracted cells. #+BEGIN_SRC python :tangle table_ocr/extract_cells/__main__.py :results none -import os import sys -import cv2 - -from table_ocr.extract_cells import extract_cell_images_from_table +from table_ocr.extract_cells import main -def main(f): - results = [] - directory, filename = os.path.split(f) - table = cv2.imread(f, cv2.IMREAD_GRAYSCALE) - rows = extract_cell_images_from_table(table) - cell_img_dir = os.path.join(directory, "cells") - os.makedirs(cell_img_dir, exist_ok=True) - for i, row in enumerate(rows): - for j, cell in enumerate(row): - cell_filename = "{:03d}-{:03d}.png".format(i, j) - path = os.path.join(cell_img_dir, cell_filename) - cv2.imwrite(path, cell) - print(path) - - -<> - -if __name__ == "__main__": - main(sys.argv[1]) +paths = main(sys.argv[1]) +print("\n".join(paths)) #+END_SRC *** table_ocr/ocr_image/ diff --git a/table_ocr/extract_cells/__init__.py b/table_ocr/extract_cells/__init__.py index 4fed823..8e39097 100644 --- a/table_ocr/extract_cells/__init__.py +++ b/table_ocr/extract_cells/__init__.py @@ -1,4 +1,5 @@ import cv2 +import os def extract_cell_images_from_table(image): BLUR_KERNEL_SIZE = (17, 17) @@ -95,3 +96,19 @@ def extract_cell_images_from_table(image): cell_images_row.append(image[y:y+h, x:x+w]) cell_images_rows.append(cell_images_row) return cell_images_rows + +def main(f): + results = [] + directory, filename = os.path.split(f) + table = cv2.imread(f, cv2.IMREAD_GRAYSCALE) + rows = extract_cell_images_from_table(table) + cell_img_dir = os.path.join(directory, "cells") + os.makedirs(cell_img_dir, exist_ok=True) + paths = [] + for i, row in enumerate(rows): + for j, cell in enumerate(row): + cell_filename = "{:03d}-{:03d}.png".format(i, j) + path = os.path.join(cell_img_dir, cell_filename) + cv2.imwrite(path, cell) + paths.append(path) + return paths diff --git a/table_ocr/extract_cells/__main__.py b/table_ocr/extract_cells/__main__.py index f1daee2..8d1d222 100644 --- a/table_ocr/extract_cells/__main__.py +++ b/table_ocr/extract_cells/__main__.py @@ -1,120 +1,6 @@ -import os import sys -import cv2 +from table_ocr.extract_cells import main -from table_ocr.extract_cells import extract_cell_images_from_table - -def main(f): - results = [] - directory, filename = os.path.split(f) - table = cv2.imread(f, cv2.IMREAD_GRAYSCALE) - rows = extract_cell_images_from_table(table) - cell_img_dir = os.path.join(directory, "cells") - os.makedirs(cell_img_dir, exist_ok=True) - for i, row in enumerate(rows): - for j, cell in enumerate(row): - cell_filename = "{:03d}-{:03d}.png".format(i, j) - path = os.path.join(cell_img_dir, cell_filename) - cv2.imwrite(path, cell) - print(path) - - -def extract_cell_images_from_table(image): - BLUR_KERNEL_SIZE = (17, 17) - STD_DEV_X_DIRECTION = 0 - STD_DEV_Y_DIRECTION = 0 - blurred = cv2.GaussianBlur(image, BLUR_KERNEL_SIZE, STD_DEV_X_DIRECTION, STD_DEV_Y_DIRECTION) - MAX_COLOR_VAL = 255 - BLOCK_SIZE = 15 - SUBTRACT_FROM_MEAN = -2 - - img_bin = cv2.adaptiveThreshold( - ~blurred, - MAX_COLOR_VAL, - cv2.ADAPTIVE_THRESH_MEAN_C, - cv2.THRESH_BINARY, - BLOCK_SIZE, - SUBTRACT_FROM_MEAN, - ) - vertical = horizontal = img_bin.copy() - SCALE = 5 - image_width, image_height = horizontal.shape - horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (int(image_width / SCALE), 1)) - horizontally_opened = cv2.morphologyEx(img_bin, cv2.MORPH_OPEN, horizontal_kernel) - vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, int(image_height / SCALE))) - vertically_opened = cv2.morphologyEx(img_bin, cv2.MORPH_OPEN, vertical_kernel) - - horizontally_dilated = cv2.dilate(horizontally_opened, cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))) - vertically_dilated = cv2.dilate(vertically_opened, cv2.getStructuringElement(cv2.MORPH_RECT, (1, 60))) - - mask = horizontally_dilated + vertically_dilated - contours, heirarchy = cv2.findContours( - mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE, - ) - - perimeter_lengths = [cv2.arcLength(c, True) for c in contours] - epsilons = [0.05 * p for p in perimeter_lengths] - approx_polys = [cv2.approxPolyDP(c, e, True) for c, e in zip(contours, epsilons)] - - # Filter out contours that aren't rectangular. Those that aren't rectangular - # are probably noise. - approx_rects = [p for p in approx_polys if len(p) == 4] - bounding_rects = [cv2.boundingRect(a) for a in approx_polys] - - # Filter out rectangles that are too narrow or too short. - MIN_RECT_WIDTH = 40 - MIN_RECT_HEIGHT = 10 - bounding_rects = [ - r for r in bounding_rects if MIN_RECT_WIDTH < r[2] and MIN_RECT_HEIGHT < r[3] - ] - - # The largest bounding rectangle is assumed to be the entire table. - # Remove it from the list. We don't want to accidentally try to OCR - # the entire table. - largest_rect = max(bounding_rects, key=lambda r: r[2] * r[3]) - bounding_rects = [b for b in bounding_rects if b is not largest_rect] - - cells = [c for c in bounding_rects] - def cell_in_same_row(c1, c2): - c1_center = c1[1] + c1[3] - c1[3] / 2 - c2_bottom = c2[1] + c2[3] - c2_top = c2[1] - return c2_top < c1_center < c2_bottom - - orig_cells = [c for c in cells] - rows = [] - while cells: - first = cells[0] - rest = cells[1:] - cells_in_same_row = sorted( - [ - c for c in rest - if cell_in_same_row(c, first) - ], - key=lambda c: c[0] - ) - - row_cells = sorted([first] + cells_in_same_row, key=lambda c: c[0]) - rows.append(row_cells) - cells = [ - c for c in rest - if not cell_in_same_row(c, first) - ] - - # Sort rows by average height of their center. - def avg_height_of_center(row): - centers = [y + h - h / 2 for x, y, w, h in row] - return sum(centers) / len(centers) - - rows.sort(key=avg_height_of_center) - cell_images_rows = [] - for row in rows: - cell_images_row = [] - for x, y, w, h in row: - cell_images_row.append(image[y:y+h, x:x+w]) - cell_images_rows.append(cell_images_row) - return cell_images_rows - -if __name__ == "__main__": - main(sys.argv[1]) +paths = main(sys.argv[1]) +print("\n".join(paths))