Development of the ocr part of AOI
Samo Penic
2018-11-21 5460bf601a854c842342a740df0f6d36ad785bbc
commit | author | age
e555c0 1 from pyzbar.pyzbar import decode
0d97e9 2 from .sid_process import getSID
e555c0 3 import cv2
SP 4 import numpy as np
69abed 5 import os
0d97e9 6 import pkg_resources
SP 7
d88ce4 8 markerfile = '/template-sq.png'  # always use slash
0d97e9 9 markerfilename = pkg_resources.resource_filename(__name__, markerfile)
SP 10
e555c0 11
SP 12
511c2e 13 class Paper:
69abed 14     def __init__(self, filename=None, sid_classifier=None, settings=None, output_path="/tmp"):
511c2e 15         self.filename = filename
69abed 16         self.output_path=output_path
511c2e 17         self.invalid = None
SP 18         self.QRData = None
e0996e 19         self.settings = {"answer_threshold": 0.25} if settings is None else settings
511c2e 20         self.errors = []
SP 21         self.warnings = []
e0996e 22         self.sid = None
762a5e 23         self.sid_classifier = sid_classifier
511c2e 24         if filename is not None:
SP 25             self.loadImage(filename)
26             self.runOcr()
e555c0 27
511c2e 28     def loadImage(self, filename, rgbchannel=0):
SP 29         self.img = cv2.imread(filename, rgbchannel)
30         if self.img is None:
31             self.errors.append("File could not be loaded!")
32             self.invalid = True
33             return
34         self.imgHeight, self.imgWidth = self.img.shape[0:2]
e555c0 35
0d97e9 36     def saveImage(self, filename="/tmp/debug_image.png"):
511c2e 37         cv2.imwrite(filename, self.img)
e555c0 38
511c2e 39     def runOcr(self):
SP 40         if self.invalid == True:
41             return
42         self.decodeQRandRotate()
43         self.imgTreshold()
d88ce4 44         cv2.imwrite('/tmp/debug_threshold.png', self.bwimg)
511c2e 45         skewAngle = 0
SP 46         #         try:
47         #             skewAngle=self.getSkewAngle()
48         #         except:
49         #             self.errors.append("Could not determine skew angle!")
50         #         self.rotateAngle(skewAngle)
e555c0 51
511c2e 52         self.generateAnswerMatrix()
e555c0 53
511c2e 54         self.saveImage()
e555c0 55
511c2e 56     def decodeQRandRotate(self):
SP 57         if self.invalid == True:
58             return
59         blur = cv2.blur(self.img, (3, 3))
60         d = decode(blur)
61         self.img = blur
62         if len(d) == 0:
63             self.errors.append("QR code could not be found!")
64             self.data = None
65             self.invalid = True
66             return
e0996e 67         if(len(d)>1): #if there are multiple codes, get first ean or qr code available.
SP 68             for dd in d:
69                 if(dd.type=="EAN13" or dd.type=="QR"):
70                     d[0]=dd
71                     break
511c2e 72         self.QRDecode = d
SP 73         self.QRData = d[0].data
74         xpos = d[0].rect.left
75         ypos = d[0].rect.top
76         # check if image is rotated wrongly
82ec6d 77         if xpos > self.imgHeight / 2.0 and ypos > self.imgWidth / 2.0:
511c2e 78             self.rotateAngle(180)
e555c0 79
511c2e 80     def rotateAngle(self, angle=0):
e0996e 81         # rot_mat = cv2.getRotationMatrix2D(
82ec6d 82         #    (self.imgHeight / 2, self.imgWidth / 2), angle, 1.0
e0996e 83         # )
511c2e 84         rot_mat = cv2.getRotationMatrix2D(
e0996e 85             (self.imgWidth / 2, self.imgHeight / 2), angle, 1.0
511c2e 86         )
SP 87         result = cv2.warpAffine(
88             self.img,
89             rot_mat,
82ec6d 90             (self.imgWidth, self.imgHeight),
511c2e 91             flags=cv2.INTER_CUBIC,
SP 92             borderMode=cv2.BORDER_CONSTANT,
93             borderValue=(255, 255, 255),
94         )
e555c0 95
511c2e 96         self.img = result
SP 97         self.imgHeight, self.imgWidth = self.img.shape[0:2]
e555c0 98
511c2e 99         # todo, make better tresholding
SP 100     def imgTreshold(self):
101         (self.thresh, self.bwimg) = cv2.threshold(
9c222b 102             self.img, 128, 255,
SP 103             cv2.THRESH_BINARY | cv2.THRESH_OTSU
511c2e 104         )
e555c0 105
511c2e 106     def getSkewAngle(self):
SP 107         neg = 255 - self.bwimg  # get negative image
0d97e9 108         cv2.imwrite("/tmp/debug_1.png", neg)
e555c0 109
511c2e 110         angle_counter = 0  # number of angles
SP 111         angle = 0.0  # collects sum of angles
112         cimg = cv2.cvtColor(self.img, cv2.COLOR_GRAY2BGR)
e555c0 113
511c2e 114         # get all the Hough lines
SP 115         for line in cv2.HoughLinesP(neg, 1, np.pi / 180, 325):
116             x1, y1, x2, y2 = line[0]
117             cv2.line(cimg, (x1, y1), (x2, y2), (0, 0, 255), 2)
118             # calculate the angle (in radians)
119             this_angle = np.arctan2(y2 - y1, x2 - x1)
120             if this_angle and abs(this_angle) <= 10:
121                 # filtered zero degree and outliers
122                 angle += this_angle
123                 angle_counter += 1
e555c0 124
511c2e 125                 # the skew is calculated of the mean of the total angles, #try block helps with division by zero.
SP 126         try:
127             skew = np.rad2deg(
128                 angle / angle_counter
129             )  # the 1.2 factor is just experimental....
130         except:
131             skew = 0
e555c0 132
0d97e9 133         cv2.imwrite("/tmp/debug_2.png", cimg)
511c2e 134         return skew
e555c0 135
511c2e 136     def locateUpMarkers(self, threshold=0.85, height=200):
0d97e9 137         template = cv2.imread(markerfilename, 0)
511c2e 138         w, h = template.shape[::-1]
d88ce4 139         crop_img = self.bwimg[0:height, :]
511c2e 140         res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
SP 141         loc = np.where(res >= threshold)
142         cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
143         # remove false matching of the squares in qr code
144         loc_filtered_x = []
145         loc_filtered_y = []
146         if len(loc[0]) == 0:
147             min_y = -1
148         else:
149             min_y = np.min(loc[0])
150             for pt in zip(*loc[::-1]):
151                 if pt[1] < min_y + 20:
152                     loc_filtered_y.append(pt[1])
153                     loc_filtered_x.append(pt[0])
154                     # order by x coordinate
155             loc_filtered_x, loc_filtered_y = zip(
156                 *sorted(zip(loc_filtered_x, loc_filtered_y))
157             )
02e0f7 158             # loc=[loc_filtered_y,loc_filtered_x]
SP 159             # remove duplicates
511c2e 160             a = np.diff(loc_filtered_x) > 40
SP 161             a = np.append(a, True)
162             loc_filtered_x = np.array(loc_filtered_x)
163             loc_filtered_y = np.array(loc_filtered_y)
164             loc = [loc_filtered_y[a], loc_filtered_x[a]]
165             for pt in zip(*loc[::-1]):
166                 cv2.rectangle(cimg, pt, (pt[0] + w, pt[1] + h), (0, 255, 255), 2)
e555c0 167
0d97e9 168         cv2.imwrite("/tmp/debug_3.png", cimg)
e555c0 169
511c2e 170         self.xMarkerLocations = loc
SP 171         return loc
e555c0 172
511c2e 173     def locateRightMarkers(self, threshold=0.85, width=200):
0d97e9 174         template = cv2.imread(markerfilename, 0)
511c2e 175         w, h = template.shape[::-1]
d88ce4 176         crop_img = self.bwimg[:, -width:]
SP 177         cv2.imwrite('/tmp/debug_right.png', crop_img)
511c2e 178         res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
SP 179         loc = np.where(res >= threshold)
180         cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
181         # remove false matching of the squares in qr code
182         loc_filtered_x = []
183         loc_filtered_y = []
184         if len(loc[1]) == 0:
185             min_x = -1
186         else:
187             max_x = np.max(loc[1])
188             for pt in zip(*loc[::-1]):
189                 if pt[1] > max_x - 20:
190                     loc_filtered_y.append(pt[1])
191                     loc_filtered_x.append(pt[0])
192                     # order by y coordinate
d88ce4 193             try:
SP 194                 loc_filtered_y, loc_filtered_x = zip(
195                     *sorted(zip(loc_filtered_y, loc_filtered_x))
196                 )
197             except:
198                 self.yMarkerLocations=[np.array([1,1]),np.array([1,2])]
199                 return self.yMarkerLocations
511c2e 200             # loc=[loc_filtered_y,loc_filtered_x]
SP 201             # remove duplicates
202             a = np.diff(loc_filtered_y) > 40
203             a = np.append(a, True)
204             loc_filtered_x = np.array(loc_filtered_x)
205             loc_filtered_y = np.array(loc_filtered_y)
206             loc = [loc_filtered_y[a], loc_filtered_x[a]]
207             for pt in zip(*loc[::-1]):
208                 cv2.rectangle(cimg, pt, (pt[0] + w, pt[1] + h), (0, 255, 255), 2)
e555c0 209
0d97e9 210         cv2.imwrite("/tmp/debug_4.png", cimg)
e555c0 211
511c2e 212         self.yMarkerLocations = [loc[0], loc[1] + self.imgWidth - width]
SP 213         return self.yMarkerLocations
e555c0 214
511c2e 215     def generateAnswerMatrix(self):
SP 216         self.locateUpMarkers()
217         self.locateRightMarkers()
e555c0 218
511c2e 219         roixoff = 10
SP 220         roiyoff = 5
221         roiwidth = 50
222         roiheight = roiwidth
223         totpx = roiwidth * roiheight
e555c0 224
511c2e 225         self.answerMatrix = []
SP 226         for y in self.yMarkerLocations[0]:
227             oneline = []
228             for x in self.xMarkerLocations[1]:
229                 roi = self.bwimg[
230                     y - roiyoff : y + int(roiheight - roiyoff),
231                     x - roixoff : x + int(roiwidth - roixoff),
232                 ]
233                 # cv2.imwrite('ans_x'+str(x)+'_y_'+str(y)+'.png',roi)
234                 black = totpx - cv2.countNonZero(roi)
235                 oneline.append(black / totpx)
236             self.answerMatrix.append(oneline)
9efc18 237
SP 238     def get_enhanced_sid(self):
02e0f7 239         if self.sid_classifier is None:
SP 240             return "x"
762a5e 241         if self.settings is not None:
e0996e 242             sid_mask = self.settings.get("sid_mask", None)
SP 243         es, err, warn = getSID(
02e0f7 244             self.img[
d5c694 245                 int(0.04 * self.imgHeight) : int(0.095 * self.imgHeight),
9c222b 246                 int(0.65 * self.imgWidth) : int(0.95 * self.imgWidth),
02e0f7 247             ],
SP 248             self.sid_classifier,
e0996e 249             sid_mask,
02e0f7 250         )
5cb7c1 251         [self.errors.append(e) for e in err]
SP 252         [self.warnings.append(w) for w in warn]
02e0f7 253         return es
0436f6 254
SP 255     def get_code_data(self):
cf921b 256         if self.QRData is None:
SP 257             self.errors.append("Could not read QR or EAN code! Not an exam?")
e0996e 258             retval = {
SP 259                 "exam_id": None,
260                 "page_no": None,
261                 "paper_id": None,
262                 "faculty_id": None,
263                 "sid": None,
0436f6 264             }
e0996e 265             return retval
SP 266         qrdata = bytes.decode(self.QRData, "utf8")
267         if self.QRDecode[0].type == "EAN13":
268             return {
269                 "exam_id": int(qrdata[0:7]),
69abed 270                 "page_no": int(qrdata[7])+1,
e0996e 271                 "paper_id": int(qrdata[-5:-1]),
SP 272                 "faculty_id": None,
273                 "sid": None,
274             }
275         else:
276             data = qrdata.split(",")
277             retval = {
278                 "exam_id": int(data[1]),
9c222b 279                 "page_no": int(data[3]),
e0996e 280                 "paper_id": int(data[2]),
SP 281                 "faculty_id": int(data[0]),
d88ce4 282                 "sid": None
e0996e 283             }
SP 284             if len(data) > 4:
285                 retval["sid"] = data[4]
0436f6 286
SP 287             return retval
288
289     def get_paper_ocr_data(self):
e0996e 290         data = self.get_code_data()
69abed 291         data["qr"] = bytes.decode(self.QRData, 'utf8')
e0996e 292         data["errors"] = self.errors
SP 293         data["warnings"] = self.warnings
294         data["up_position"] = (
9c222b 295             list(self.xMarkerLocations[0] / self.imgWidth),
SP 296             list(self.xMarkerLocations[1] / self.imgHeight),
e0996e 297         )
SP 298         data["right_position"] = (
9c222b 299             list(self.yMarkerLocations[0] / self.imgWidth),
e0996e 300             list(self.yMarkerLocations[1] / self.imgHeight),
SP 301         )
302         data["ans_matrix"] = (
303             (np.array(self.answerMatrix) > self.settings["answer_threshold"]) * 1
304         ).tolist()
9c222b 305         if data["sid"] is None and data["page_no"] == 1:
e0996e 306             data["sid"] = self.get_enhanced_sid()
69abed 307         output_filename=os.path.join(self.output_path, '.'.join(self.filename.split('/')[-1].split('.')[:-1])+".png")
SP 308         cv2.imwrite(output_filename, self.img)
309         data['output_filename']=output_filename
0436f6 310         return data