Development of the ocr part of AOI
Samo Penic
2019-06-11 5d557801d61beb4970ffc4c62ba81cd0cd76db68
commit | author | age
e555c0 1 from pyzbar.pyzbar import decode
0d97e9 2 from .sid_process import getSID
e555c0 3 import cv2
SP 4 import numpy as np
69abed 5 import os
0d97e9 6 import pkg_resources
5d5578 7 from .rotation_wrapper import get_scan_angle
0d97e9 8
# Square alignment-marker template bundled with this package; it is the
# template image used by Paper.locateUpMarkers / Paper.locateRightMarkers.
# The value is a pkg_resources resource name, and those are always written
# with "/" as the separator, whatever the operating system.
markerfile = "/template-sq.png"
markerfilename = pkg_resources.resource_filename(__name__, markerfile)
SP 11
e555c0 12
SP 13
class Paper:
    """One scanned exam page.

    Loads the image, decodes its QR/EAN code, corrects orientation, locates
    the alignment markers and measures how much of every answer box is
    filled. When ``filename`` is given, all of this runs from the
    constructor. Several steps write debug images to /tmp.
    """

    def __init__(self, filename=None, sid_classifier=None, settings=None, output_path="/tmp"):
        """Create a page and, if ``filename`` is given, load and process it.

        Args:
            filename: path of the scanned image; ``None`` creates an empty page.
            sid_classifier: passed to ``getSID`` by ``get_enhanced_sid``;
                ``None`` disables reading the student ID from the image.
            settings: dict of options. ``answer_threshold`` (default 0.25)
                and, optionally, ``sid_mask`` are read.
            output_path: directory where ``get_paper_ocr_data`` writes the
                processed image.
        """
        self.filename = filename
        self.output_path = output_path
        self.invalid = None  # becomes True when loading or code decoding fails
        self.QRData = None  # raw payload (bytes) of the selected QR/EAN code
        self.settings = {"answer_threshold": 0.25} if settings is None else settings
        self.errors = []
        self.warnings = []
        self.sid = None
        self.sid_classifier = sid_classifier
        if filename is not None:
            self.loadImage(filename)
            self.runOcr()

    def loadImage(self, filename, rgbchannel=0):
        """Read ``filename`` into ``self.img`` and record its size.

        ``rgbchannel`` is passed to ``cv2.imread`` as the flags argument; the
        default 0 loads a single-channel grayscale image, which the rest of
        the class assumes (it converts with COLOR_GRAY2BGR). On failure an
        error is recorded and the page is marked invalid.
        """
        self.img = cv2.imread(filename, rgbchannel)
        if self.img is None:
            self.errors.append("File could not be loaded!")
            self.invalid = True
            return
        self.imgHeight, self.imgWidth = self.img.shape[0:2]

    def saveImage(self, filename="/tmp/debug_image.png"):
        """Write the current (processed) image to ``filename``."""
        cv2.imwrite(filename, self.img)

    def runOcr(self):
        """Run the pipeline: decode/rotate, threshold, measure answer boxes.

        Does nothing when the page is already invalid. Processing continues
        even if no code is found; ``get_paper_ocr_data`` then returns None.
        """
        if self.invalid:
            return
        self.decodeQRandRotate()
        self.imgTreshold()
        cv2.imwrite("/tmp/debug_threshold.png", self.bwimg)
        # Small-angle deskewing happens inside decodeQRandRotate (via
        # get_scan_angle); getSkewAngle is not used by this pipeline.
        self.generateAnswerMatrix()
        self.saveImage()

    def decodeQRandRotate(self):
        """Decode the page's QR/EAN code and bring the image upright.

        The image is blurred with a 3x3 box filter before decoding, and the
        blurred image replaces ``self.img``. Without a code the page is
        marked invalid. Otherwise the image is rotated by 180 degrees when
        the code lies in the bottom-right quadrant, then by the small skew
        angle reported by ``get_scan_angle``.
        """
        if self.invalid:
            return
        blur = cv2.blur(self.img, (3, 3))
        codes = decode(blur)
        self.img = blur
        if not codes:
            self.errors.append("QR code could not be found!")
            # NOTE(review): ``self.data`` is not read in this file; kept for
            # compatibility. ``self.QRData`` already stays None here.
            self.data = None
            self.invalid = True
            return
        # With several codes on the page, prefer the first EAN13 or QR code;
        # otherwise keep whichever was found first.
        codes[0] = next((c for c in codes if c.type in ("EAN13", "QR")), codes[0])
        self.QRDecode = codes
        self.QRData = codes[0].data
        xpos = codes[0].rect.left
        ypos = codes[0].rect.top
        # A code found in the bottom-right quadrant means the page was
        # scanned upside down. x is compared with the width, y with the height.
        if xpos > self.imgWidth / 2.0 and ypos > self.imgHeight / 2.0:
            self.rotateAngle(180)

        # Small rotation check. NOTE: the angle is measured on the original
        # file on disk, not on the blurred/rotated self.img.
        angle = get_scan_angle(self.filename)
        if abs(angle) > 0.1:
            print("Rotating for angle of {}.".format(angle))
            self.rotateAngle(-angle)

    def rotateAngle(self, angle=0):
        """Rotate ``self.img`` by ``angle`` degrees about its centre.

        Positive angles rotate counter-clockwise (OpenCV convention). The
        canvas size is kept, so corners may be cropped; uncovered areas are
        filled with white. The stored image size is refreshed afterwards.
        """
        center = (self.imgWidth / 2, self.imgHeight / 2)
        rot_mat = cv2.getRotationMatrix2D(center, angle, 1.0)
        self.img = cv2.warpAffine(
            self.img,
            rot_mat,
            (self.imgWidth, self.imgHeight),
            flags=cv2.INTER_CUBIC,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(255, 255, 255),
        )
        self.imgHeight, self.imgWidth = self.img.shape[0:2]

    # TODO: make better thresholding
    def imgTreshold(self):
        """Binarise ``self.img`` into ``self.bwimg`` using Otsu's method.

        With THRESH_OTSU, OpenCV ignores the 128 passed in and picks the
        threshold itself; the chosen value is stored in ``self.thresh``.
        """
        (self.thresh, self.bwimg) = cv2.threshold(
            self.img, 128, 255,
            cv2.THRESH_BINARY | cv2.THRESH_OTSU
        )

    def getSkewAngle(self, max_angle_deg=10):
        """Estimate the page skew in degrees from long straight lines.

        Runs a probabilistic Hough transform on the inverted binary image and
        averages the angles of the detected segments. Segments that are
        exactly horizontal, or tilted more than ``max_angle_deg`` degrees,
        are ignored. Returns 0 when no usable segment is found. Requires
        ``imgTreshold`` to have run (reads ``self.bwimg``); writes debug
        images to /tmp.
        """
        neg = 255 - self.bwimg  # negative image: dark lines become white
        cv2.imwrite("/tmp/debug_1.png", neg)

        max_angle = np.deg2rad(max_angle_deg)  # arctan2 below yields radians
        angle_sum = 0.0
        angle_count = 0
        cimg = cv2.cvtColor(self.img, cv2.COLOR_GRAY2BGR)

        lines = cv2.HoughLinesP(neg, 1, np.pi / 180, 325)
        if lines is not None:  # HoughLinesP returns None when nothing is found
            for line in lines:
                x1, y1, x2, y2 = line[0]
                cv2.line(cimg, (x1, y1), (x2, y2), (0, 0, 255), 2)
                this_angle = np.arctan2(y2 - y1, x2 - x1)
                # Skip exactly horizontal segments and steep outliers
                # (vertical or reversed segments come out near +-90/180 deg).
                if this_angle and abs(this_angle) <= max_angle:
                    angle_sum += this_angle
                    angle_count += 1

        # The skew is the mean angle of the accepted segments.
        skew = np.rad2deg(angle_sum / angle_count) if angle_count else 0

        cv2.imwrite("/tmp/debug_2.png", cimg)
        return skew

    def locateUpMarkers(self, threshold=0.85, height=200):
        """Find the column markers along the top edge of the page.

        Template-matches the marker image within the top ``height`` rows of
        the binary image. Only matches within 20 px of the topmost match are
        kept, which removes false hits on the squares of the QR code. They
        are ordered by x, and hits less than 40 px apart in x are merged,
        keeping the last one of each run.

        Returns:
            ``[ys, xs]``: top-left pixel coordinates of the markers, also
            stored in ``self.xMarkerLocations``. When nothing matches, the
            empty ``np.where`` result is returned instead.
        """
        template = cv2.imread(markerfilename, 0)
        w, h = template.shape[::-1]
        crop_img = self.bwimg[0:height, :]
        res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
        loc = np.where(res >= threshold)
        cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
        if len(loc[0]) > 0:
            min_y = np.min(loc[0])
            # Top row of matches only, as (x, y) pairs sorted by x.
            kept = sorted((x, y) for y, x in zip(*loc) if y < min_y + 20)
            xs = np.array([x for x, _ in kept])
            ys = np.array([y for _, y in kept])
            # Duplicate hits of one marker are adjacent; keep the last of each run.
            keep = np.append(np.diff(xs) > 40, True)
            loc = [ys[keep], xs[keep]]
            for pt in zip(*loc[::-1]):
                cv2.rectangle(cimg, pt, (pt[0] + w, pt[1] + h), (0, 255, 255), 2)

        cv2.imwrite("/tmp/debug_3.png", cimg)

        self.xMarkerLocations = loc
        return loc

    def locateRightMarkers(self, threshold=0.85, width=200):
        """Find the row markers along the right edge of the page.

        Template-matches the marker image within the rightmost ``width``
        columns of the binary image, orders the kept matches by y, and merges
        hits less than 40 px apart in y, keeping the last one of each run.

        Returns:
            ``[ys, xs]`` in full-image pixel coordinates, also stored in
            ``self.yMarkerLocations``. If every match is filtered out, the
            fixed fallback ``[array([1, 1]), array([1, 2])]`` is used instead.
        """
        template = cv2.imread(markerfilename, 0)
        w, h = template.shape[::-1]
        crop_img = self.bwimg[:, -width:]
        cv2.imwrite("/tmp/debug_right.png", crop_img)
        res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
        loc = np.where(res >= threshold)
        cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
        if len(loc[1]) > 0:
            max_x = np.max(loc[1])
            # NOTE(review): this compares each match's y coordinate with the
            # largest x coordinate. It therefore drops matches near the top of
            # the strip instead of keeping the rightmost column. Changing it
            # changes which rows generateAnswerMatrix reads; confirm intent first.
            kept = sorted((y, x) for y, x in zip(*loc) if y > max_x - 20)
            if not kept:
                self.yMarkerLocations = [np.array([1, 1]), np.array([1, 2])]
                return self.yMarkerLocations
            ys = np.array([y for y, _ in kept])
            xs = np.array([x for _, x in kept])
            # Duplicate hits of one marker are adjacent; keep the last of each run.
            keep = np.append(np.diff(ys) > 40, True)
            loc = [ys[keep], xs[keep]]
            for pt in zip(*loc[::-1]):
                cv2.rectangle(cimg, pt, (pt[0] + w, pt[1] + h), (0, 255, 255), 2)

        cv2.imwrite("/tmp/debug_4.png", cimg)

        # Shift x from strip coordinates back to full-image coordinates.
        self.yMarkerLocations = [loc[0], loc[1] + self.imgWidth - width]
        return self.yMarkerLocations

    def generateAnswerMatrix(self):
        """Measure the black-pixel fraction of every answer box.

        A box of 55x55 px is taken at each (row marker y, column marker x)
        intersection, starting 4 px left of the column marker. The result,
        one list per row with values in [0, 1], is stored in
        ``self.answerMatrix``. An annotated debug image is written to
        /tmp/debug_answers.png.

        NOTE(review): the denominator is the nominal box area, so a box
        clipped by the image border counts its missing pixels as black.
        """
        self.locateUpMarkers()
        self.locateRightMarkers()

        roixoff = 4
        roiyoff = 0
        roiwidth = 55
        roiheight = roiwidth
        totpx = roiwidth * roiheight
        cimg = cv2.cvtColor(self.img, cv2.COLOR_GRAY2BGR)
        self.answerMatrix = []
        for y in self.yMarkerLocations[0]:
            top, bottom = y - roiyoff, y + int(roiheight - roiyoff)
            row = []
            for x in self.xMarkerLocations[1]:
                left, right = x - roixoff, x + int(roiwidth - roixoff)
                roi = self.bwimg[top:bottom, left:right]
                black = totpx - cv2.countNonZero(roi)
                row.append(black / totpx)
                cv2.rectangle(cimg, (left, top), (right, bottom), (0, 255, 255), 2)
            self.answerMatrix.append(row)
        # Written once, after every box has been drawn.
        cv2.imwrite("/tmp/debug_answers.png", cimg)

    def get_enhanced_sid(self):
        """Read the student ID from the image with ``self.sid_classifier``.

        Returns ``"x"`` when no classifier is configured. The region passed
        to ``getSID`` spans 4-9.5 % of the page height and 65-95 % of the
        width (presumably the SID box on the form). Errors and warnings
        reported by ``getSID`` are added to this page's lists.
        """
        if self.sid_classifier is None:
            return "x"
        sid_mask = None
        if self.settings is not None:
            sid_mask = self.settings.get("sid_mask", None)
        sid_region = self.img[
            int(0.04 * self.imgHeight) : int(0.095 * self.imgHeight),
            int(0.65 * self.imgWidth) : int(0.95 * self.imgWidth),
        ]
        es, err, warn = getSID(sid_region, self.sid_classifier, sid_mask)
        self.errors.extend(err)
        self.warnings.extend(warn)
        return es

    def get_code_data(self):
        """Parse the decoded QR/EAN payload into exam metadata.

        EAN13: ``exam_id`` is the first 7 digits, ``page_no`` is digit 8
        plus one, and ``paper_id`` is the four digits before the last one.
        QR: comma-separated ``faculty_id,exam_id,paper_id,page_no[,sid]``.

        Returns:
            Dict with keys exam_id, page_no, paper_id, faculty_id, sid. If no
            code was decoded, all values are None and an error is recorded.
        """
        if self.QRData is None:
            self.errors.append("Could not read QR or EAN code! Not an exam?")
            return {
                "exam_id": None,
                "page_no": None,
                "paper_id": None,
                "faculty_id": None,
                "sid": None,
            }
        qrdata = bytes.decode(self.QRData, "utf8")
        if self.QRDecode[0].type == "EAN13":
            return {
                "exam_id": int(qrdata[0:7]),
                "page_no": int(qrdata[7]) + 1,
                "paper_id": int(qrdata[-5:-1]),
                "faculty_id": None,
                "sid": None,
            }
        data = qrdata.split(",")
        retval = {
            "exam_id": int(data[1]),
            "page_no": int(data[3]),
            "paper_id": int(data[2]),
            "faculty_id": int(data[0]),
            "sid": None,
        }
        if len(data) > 4:
            retval["sid"] = data[4]
        return retval

    def get_paper_ocr_data(self):
        """Collect all results for this page into one dict.

        Returns None when no code was decoded. Otherwise returns the
        ``get_code_data`` dict extended with the raw payload, errors,
        warnings, normalised marker positions and the thresholded answer
        matrix (1 = marked). For page 1 without a SID in the code, the SID is
        read from the image. The processed image is written to
        ``output_path/<input name without extension>.png``, and that path is
        returned under ``output_filename``.
        """
        data = self.get_code_data()
        if self.QRData is None:
            return None
        data["qr"] = bytes.decode(self.QRData, "utf8")
        data["errors"] = self.errors
        data["warnings"] = self.warnings
        # NOTE(review): marker locations are [ys, xs], so these divide y by
        # the width and x by the height. Consumers may rely on this; confirm
        # before swapping.
        data["up_position"] = (
            list(self.xMarkerLocations[0] / self.imgWidth),
            list(self.xMarkerLocations[1] / self.imgHeight),
        )
        data["right_position"] = (
            list(self.yMarkerLocations[0] / self.imgWidth),
            list(self.yMarkerLocations[1] / self.imgHeight),
        )
        threshold = self.settings.get("answer_threshold", 0.25)
        data["ans_matrix"] = (
            (np.array(self.answerMatrix) > threshold) * 1
        ).tolist()
        if data["sid"] is None and data["page_no"] == 1:
            data["sid"] = self.get_enhanced_sid()
        # splitext keeps extension-less names intact (they used to map to ".png").
        base_name = os.path.splitext(os.path.basename(self.filename))[0]
        output_filename = os.path.join(self.output_path, base_name + ".png")
        cv2.imwrite(output_filename, self.img)
        data["output_filename"] = output_filename
        print(np.array(self.answerMatrix))
        return data