From 5d557801d61beb4970ffc4c62ba81cd0cd76db68 Mon Sep 17 00:00:00 2001
From: Samo Penic <samo.penic@gmail.com>
Date: Tue, 11 Jun 2019 17:46:17 +0000
Subject: [PATCH] Angle of page rotation can be determined with c wrapper function.

---
 aoi_ocr/Ocr.py |   65 ++++++++++++++++++++++----------
 1 files changed, 44 insertions(+), 21 deletions(-)

diff --git a/aoi_ocr/Ocr.py b/aoi_ocr/Ocr.py
index d743ecc..842a237 100644
--- a/aoi_ocr/Ocr.py
+++ b/aoi_ocr/Ocr.py
@@ -2,17 +2,19 @@
 from .sid_process import getSID
 import cv2
 import numpy as np
-
+import os
 import pkg_resources
+from .rotation_wrapper import get_scan_angle
 
-markerfile = '/template.png'  # always use slash
+markerfile = '/template-sq.png'  # always use slash
 markerfilename = pkg_resources.resource_filename(__name__, markerfile)
 
 
 
 class Paper:
-    def __init__(self, filename=None, sid_classifier=None, settings=None):
+    def __init__(self, filename=None, sid_classifier=None, settings=None, output_path="/tmp"):
         self.filename = filename
+        self.output_path=output_path
         self.invalid = None
         self.QRData = None
         self.settings = {"answer_threshold": 0.25} if settings is None else settings
@@ -40,6 +42,7 @@
             return
         self.decodeQRandRotate()
         self.imgTreshold()
+        cv2.imwrite('/tmp/debug_threshold.png', self.bwimg)
         skewAngle = 0
         # 		try:
         # 			skewAngle=self.getSkewAngle()
@@ -75,6 +78,12 @@
         if xpos > self.imgHeight / 2.0 and ypos > self.imgWidth / 2.0:
             self.rotateAngle(180)
 
+        #small rotation check
+        angle=get_scan_angle(self.filename)
+        if angle>0.1 or angle<-0.1:
+            print("Rotating for angle of {}.".format(angle))
+            self.rotateAngle(-angle)
+
     def rotateAngle(self, angle=0):
         # rot_mat = cv2.getRotationMatrix2D(
         #    (self.imgHeight / 2, self.imgWidth / 2), angle, 1.0
@@ -95,10 +104,10 @@
         self.imgHeight, self.imgWidth = self.img.shape[0:2]
 
         # todo, make better tresholding
-
     def imgTreshold(self):
         (self.thresh, self.bwimg) = cv2.threshold(
-            self.img, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU
+            self.img, 128, 255,
+            cv2.THRESH_BINARY | cv2.THRESH_OTSU
         )
 
     def getSkewAngle(self):
@@ -134,7 +143,7 @@
     def locateUpMarkers(self, threshold=0.85, height=200):
         template = cv2.imread(markerfilename, 0)
         w, h = template.shape[::-1]
-        crop_img = self.img[0:height, :]
+        crop_img = self.bwimg[0:height, :]
         res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
         loc = np.where(res >= threshold)
         cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
@@ -171,7 +180,8 @@
     def locateRightMarkers(self, threshold=0.85, width=200):
         template = cv2.imread(markerfilename, 0)
         w, h = template.shape[::-1]
-        crop_img = self.img[:, -width:]
+        crop_img = self.bwimg[:, -width:]
+        cv2.imwrite('/tmp/debug_right.png', crop_img)
         res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
         loc = np.where(res >= threshold)
         cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
@@ -187,9 +197,13 @@
                     loc_filtered_y.append(pt[1])
                     loc_filtered_x.append(pt[0])
                     # order by y coordinate
-            loc_filtered_y, loc_filtered_x = zip(
-                *sorted(zip(loc_filtered_y, loc_filtered_x))
-            )
+            try:
+                loc_filtered_y, loc_filtered_x = zip(
+                    *sorted(zip(loc_filtered_y, loc_filtered_x))
+                )
+            except:
+                self.yMarkerLocations=[np.array([1,1]),np.array([1,2])]
+                return self.yMarkerLocations
             # loc=[loc_filtered_y,loc_filtered_x]
             # remove duplicates
             a = np.diff(loc_filtered_y) > 40
@@ -209,12 +223,12 @@
         self.locateUpMarkers()
         self.locateRightMarkers()
 
-        roixoff = 10
-        roiyoff = 5
-        roiwidth = 50
+        roixoff = 4
+        roiyoff = 0
+        roiwidth = 55
         roiheight = roiwidth
         totpx = roiwidth * roiheight
-
+        cimg = cv2.cvtColor(self.img, cv2.COLOR_GRAY2BGR)
         self.answerMatrix = []
         for y in self.yMarkerLocations[0]:
             oneline = []
@@ -226,6 +240,8 @@
                 # cv2.imwrite('ans_x'+str(x)+'_y_'+str(y)+'.png',roi)
                 black = totpx - cv2.countNonZero(roi)
                 oneline.append(black / totpx)
+                cv2.rectangle(cimg, (x - roixoff,y - roiyoff), (x + int(roiwidth - roixoff),y + int(roiheight - roiyoff)), (0, 255, 255), 2)
+            cv2.imwrite('/tmp/debug_answers.png',cimg)
             self.answerMatrix.append(oneline)
 
     def get_enhanced_sid(self):
@@ -236,7 +252,7 @@
         es, err, warn = getSID(
             self.img[
                 int(0.04 * self.imgHeight) : int(0.095 * self.imgHeight),
-                int(0.7 * self.imgWidth) : int(0.99 * self.imgWidth),
+                int(0.65 * self.imgWidth) : int(0.95 * self.imgWidth),
             ],
             self.sid_classifier,
             sid_mask,
@@ -260,7 +276,7 @@
         if self.QRDecode[0].type == "EAN13":
             return {
                 "exam_id": int(qrdata[0:7]),
-                "page_no": int(qrdata[7]),
+                "page_no": int(qrdata[7])+1,
                 "paper_id": int(qrdata[-5:-1]),
                 "faculty_id": None,
                 "sid": None,
@@ -272,6 +288,7 @@
                 "page_no": int(data[3]),
                 "paper_id": int(data[2]),
                 "faculty_id": int(data[0]),
+                "sid": None
             }
             if len(data) > 4:
                 retval["sid"] = data[4]
@@ -280,20 +297,26 @@
 
     def get_paper_ocr_data(self):
         data = self.get_code_data()
-        data["qr"] = self.QRData
+        if self.QRData is None:
+            return None
+        data["qr"] = bytes.decode(self.QRData, 'utf8')
         data["errors"] = self.errors
         data["warnings"] = self.warnings
         data["up_position"] = (
-            list(self.xMarkerLocations[1] / self.imgWidth),
-            list(self.yMarkerLocations[1] / self.imgHeight),
+            list(self.xMarkerLocations[0] / self.imgWidth),
+            list(self.xMarkerLocations[1] / self.imgHeight),
         )
         data["right_position"] = (
-            list(self.xMarkerLocations[1] / self.imgWidth),
+            list(self.yMarkerLocations[0] / self.imgWidth),
             list(self.yMarkerLocations[1] / self.imgHeight),
         )
         data["ans_matrix"] = (
             (np.array(self.answerMatrix) > self.settings["answer_threshold"]) * 1
         ).tolist()
-        if data["sid"] is None and data["page_no"] == 0:
+        if data["sid"] is None and data["page_no"] == 1:
             data["sid"] = self.get_enhanced_sid()
+        output_filename=os.path.join(self.output_path, '.'.join(self.filename.split('/')[-1].split('.')[:-1])+".png")
+        cv2.imwrite(output_filename, self.img)
+        data['output_filename']=output_filename
+        print(np.array(self.answerMatrix))
         return data

--
Gitblit v1.9.3