Development of the ocr part of AOI
Samo Penic
2018-11-18 2fec6b84c1ebc8ea0c257185b83266aae9f57639
commit | author | age
e555c0 1 from pyzbar.pyzbar import decode
0d97e9 2 from .sid_process import getSID
e555c0 3 import cv2
SP 4 import numpy as np
69abed 5 import os
0d97e9 6 import pkg_resources
SP 7
8 markerfile = '/template.png'  # always use slash
9 markerfilename = pkg_resources.resource_filename(__name__, markerfile)
10
e555c0 11
SP 12
511c2e 13 class Paper:
69abed 14     def __init__(self, filename=None, sid_classifier=None, settings=None, output_path="/tmp"):
511c2e 15         self.filename = filename
69abed 16         self.output_path=output_path
511c2e 17         self.invalid = None
SP 18         self.QRData = None
e0996e 19         self.settings = {"answer_threshold": 0.25} if settings is None else settings
511c2e 20         self.errors = []
SP 21         self.warnings = []
e0996e 22         self.sid = None
762a5e 23         self.sid_classifier = sid_classifier
511c2e 24         if filename is not None:
SP 25             self.loadImage(filename)
26             self.runOcr()
e555c0 27
511c2e 28     def loadImage(self, filename, rgbchannel=0):
SP 29         self.img = cv2.imread(filename, rgbchannel)
30         if self.img is None:
31             self.errors.append("File could not be loaded!")
32             self.invalid = True
33             return
34         self.imgHeight, self.imgWidth = self.img.shape[0:2]
e555c0 35
0d97e9 36     def saveImage(self, filename="/tmp/debug_image.png"):
511c2e 37         cv2.imwrite(filename, self.img)
e555c0 38
511c2e 39     def runOcr(self):
SP 40         if self.invalid == True:
41             return
42         self.decodeQRandRotate()
43         self.imgTreshold()
44         skewAngle = 0
45         #         try:
46         #             skewAngle=self.getSkewAngle()
47         #         except:
48         #             self.errors.append("Could not determine skew angle!")
49         #         self.rotateAngle(skewAngle)
e555c0 50
511c2e 51         self.generateAnswerMatrix()
e555c0 52
511c2e 53         self.saveImage()
e555c0 54
511c2e 55     def decodeQRandRotate(self):
SP 56         if self.invalid == True:
57             return
58         blur = cv2.blur(self.img, (3, 3))
59         d = decode(blur)
60         self.img = blur
61         if len(d) == 0:
62             self.errors.append("QR code could not be found!")
63             self.data = None
64             self.invalid = True
65             return
e0996e 66         if(len(d)>1): #if there are multiple codes, get first ean or qr code available.
SP 67             for dd in d:
68                 if(dd.type=="EAN13" or dd.type=="QR"):
69                     d[0]=dd
70                     break
511c2e 71         self.QRDecode = d
SP 72         self.QRData = d[0].data
73         xpos = d[0].rect.left
74         ypos = d[0].rect.top
75         # check if image is rotated wrongly
82ec6d 76         if xpos > self.imgHeight / 2.0 and ypos > self.imgWidth / 2.0:
511c2e 77             self.rotateAngle(180)
e555c0 78
511c2e 79     def rotateAngle(self, angle=0):
e0996e 80         # rot_mat = cv2.getRotationMatrix2D(
82ec6d 81         #    (self.imgHeight / 2, self.imgWidth / 2), angle, 1.0
e0996e 82         # )
511c2e 83         rot_mat = cv2.getRotationMatrix2D(
e0996e 84             (self.imgWidth / 2, self.imgHeight / 2), angle, 1.0
511c2e 85         )
SP 86         result = cv2.warpAffine(
87             self.img,
88             rot_mat,
82ec6d 89             (self.imgWidth, self.imgHeight),
511c2e 90             flags=cv2.INTER_CUBIC,
SP 91             borderMode=cv2.BORDER_CONSTANT,
92             borderValue=(255, 255, 255),
93         )
e555c0 94
511c2e 95         self.img = result
SP 96         self.imgHeight, self.imgWidth = self.img.shape[0:2]
e555c0 97
511c2e 98         # todo, make better tresholding
e555c0 99
511c2e 100     def imgTreshold(self):
SP 101         (self.thresh, self.bwimg) = cv2.threshold(
102             self.img, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU
103         )
e555c0 104
511c2e 105     def getSkewAngle(self):
SP 106         neg = 255 - self.bwimg  # get negative image
0d97e9 107         cv2.imwrite("/tmp/debug_1.png", neg)
e555c0 108
511c2e 109         angle_counter = 0  # number of angles
SP 110         angle = 0.0  # collects sum of angles
111         cimg = cv2.cvtColor(self.img, cv2.COLOR_GRAY2BGR)
e555c0 112
511c2e 113         # get all the Hough lines
SP 114         for line in cv2.HoughLinesP(neg, 1, np.pi / 180, 325):
115             x1, y1, x2, y2 = line[0]
116             cv2.line(cimg, (x1, y1), (x2, y2), (0, 0, 255), 2)
117             # calculate the angle (in radians)
118             this_angle = np.arctan2(y2 - y1, x2 - x1)
119             if this_angle and abs(this_angle) <= 10:
120                 # filtered zero degree and outliers
121                 angle += this_angle
122                 angle_counter += 1
e555c0 123
511c2e 124                 # the skew is calculated of the mean of the total angles, #try block helps with division by zero.
SP 125         try:
126             skew = np.rad2deg(
127                 angle / angle_counter
128             )  # the 1.2 factor is just experimental....
129         except:
130             skew = 0
e555c0 131
0d97e9 132         cv2.imwrite("/tmp/debug_2.png", cimg)
511c2e 133         return skew
e555c0 134
511c2e 135     def locateUpMarkers(self, threshold=0.85, height=200):
0d97e9 136         template = cv2.imread(markerfilename, 0)
511c2e 137         w, h = template.shape[::-1]
SP 138         crop_img = self.img[0:height, :]
139         res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
140         loc = np.where(res >= threshold)
141         cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
142         # remove false matching of the squares in qr code
143         loc_filtered_x = []
144         loc_filtered_y = []
145         if len(loc[0]) == 0:
146             min_y = -1
147         else:
148             min_y = np.min(loc[0])
149             for pt in zip(*loc[::-1]):
150                 if pt[1] < min_y + 20:
151                     loc_filtered_y.append(pt[1])
152                     loc_filtered_x.append(pt[0])
153                     # order by x coordinate
154             loc_filtered_x, loc_filtered_y = zip(
155                 *sorted(zip(loc_filtered_x, loc_filtered_y))
156             )
02e0f7 157             # loc=[loc_filtered_y,loc_filtered_x]
SP 158             # remove duplicates
511c2e 159             a = np.diff(loc_filtered_x) > 40
SP 160             a = np.append(a, True)
161             loc_filtered_x = np.array(loc_filtered_x)
162             loc_filtered_y = np.array(loc_filtered_y)
163             loc = [loc_filtered_y[a], loc_filtered_x[a]]
164             for pt in zip(*loc[::-1]):
165                 cv2.rectangle(cimg, pt, (pt[0] + w, pt[1] + h), (0, 255, 255), 2)
e555c0 166
0d97e9 167         cv2.imwrite("/tmp/debug_3.png", cimg)
e555c0 168
511c2e 169         self.xMarkerLocations = loc
SP 170         return loc
e555c0 171
511c2e 172     def locateRightMarkers(self, threshold=0.85, width=200):
0d97e9 173         template = cv2.imread(markerfilename, 0)
511c2e 174         w, h = template.shape[::-1]
SP 175         crop_img = self.img[:, -width:]
176         res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
177         loc = np.where(res >= threshold)
178         cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
179         # remove false matching of the squares in qr code
180         loc_filtered_x = []
181         loc_filtered_y = []
182         if len(loc[1]) == 0:
183             min_x = -1
184         else:
185             max_x = np.max(loc[1])
186             for pt in zip(*loc[::-1]):
187                 if pt[1] > max_x - 20:
188                     loc_filtered_y.append(pt[1])
189                     loc_filtered_x.append(pt[0])
190                     # order by y coordinate
191             loc_filtered_y, loc_filtered_x = zip(
192                 *sorted(zip(loc_filtered_y, loc_filtered_x))
193             )
194             # loc=[loc_filtered_y,loc_filtered_x]
195             # remove duplicates
196             a = np.diff(loc_filtered_y) > 40
197             a = np.append(a, True)
198             loc_filtered_x = np.array(loc_filtered_x)
199             loc_filtered_y = np.array(loc_filtered_y)
200             loc = [loc_filtered_y[a], loc_filtered_x[a]]
201             for pt in zip(*loc[::-1]):
202                 cv2.rectangle(cimg, pt, (pt[0] + w, pt[1] + h), (0, 255, 255), 2)
e555c0 203
0d97e9 204         cv2.imwrite("/tmp/debug_4.png", cimg)
e555c0 205
511c2e 206         self.yMarkerLocations = [loc[0], loc[1] + self.imgWidth - width]
SP 207         return self.yMarkerLocations
e555c0 208
511c2e 209     def generateAnswerMatrix(self):
SP 210         self.locateUpMarkers()
211         self.locateRightMarkers()
e555c0 212
511c2e 213         roixoff = 10
SP 214         roiyoff = 5
215         roiwidth = 50
216         roiheight = roiwidth
217         totpx = roiwidth * roiheight
e555c0 218
511c2e 219         self.answerMatrix = []
SP 220         for y in self.yMarkerLocations[0]:
221             oneline = []
222             for x in self.xMarkerLocations[1]:
223                 roi = self.bwimg[
224                     y - roiyoff : y + int(roiheight - roiyoff),
225                     x - roixoff : x + int(roiwidth - roixoff),
226                 ]
227                 # cv2.imwrite('ans_x'+str(x)+'_y_'+str(y)+'.png',roi)
228                 black = totpx - cv2.countNonZero(roi)
229                 oneline.append(black / totpx)
230             self.answerMatrix.append(oneline)
9efc18 231
SP 232     def get_enhanced_sid(self):
02e0f7 233         if self.sid_classifier is None:
SP 234             return "x"
762a5e 235         if self.settings is not None:
e0996e 236             sid_mask = self.settings.get("sid_mask", None)
SP 237         es, err, warn = getSID(
02e0f7 238             self.img[
d5c694 239                 int(0.04 * self.imgHeight) : int(0.095 * self.imgHeight),
02e0f7 240                 int(0.7 * self.imgWidth) : int(0.99 * self.imgWidth),
SP 241             ],
242             self.sid_classifier,
e0996e 243             sid_mask,
02e0f7 244         )
5cb7c1 245         [self.errors.append(e) for e in err]
SP 246         [self.warnings.append(w) for w in warn]
02e0f7 247         return es
0436f6 248
SP 249     def get_code_data(self):
cf921b 250         if self.QRData is None:
SP 251             self.errors.append("Could not read QR or EAN code! Not an exam?")
e0996e 252             retval = {
SP 253                 "exam_id": None,
254                 "page_no": None,
255                 "paper_id": None,
256                 "faculty_id": None,
257                 "sid": None,
0436f6 258             }
e0996e 259             return retval
SP 260         qrdata = bytes.decode(self.QRData, "utf8")
261         if self.QRDecode[0].type == "EAN13":
262             return {
263                 "exam_id": int(qrdata[0:7]),
69abed 264                 "page_no": int(qrdata[7])+1,
e0996e 265                 "paper_id": int(qrdata[-5:-1]),
SP 266                 "faculty_id": None,
267                 "sid": None,
268             }
269         else:
270             data = qrdata.split(",")
271             retval = {
272                 "exam_id": int(data[1]),
69abed 273                 "page_no": int(data[3])+1,
e0996e 274                 "paper_id": int(data[2]),
SP 275                 "faculty_id": int(data[0]),
276             }
277             if len(data) > 4:
278                 retval["sid"] = data[4]
0436f6 279
SP 280             return retval
281
282     def get_paper_ocr_data(self):
e0996e 283         data = self.get_code_data()
69abed 284         data["qr"] = bytes.decode(self.QRData, 'utf8')
e0996e 285         data["errors"] = self.errors
SP 286         data["warnings"] = self.warnings
287         data["up_position"] = (
288             list(self.xMarkerLocations[1] / self.imgWidth),
289             list(self.yMarkerLocations[1] / self.imgHeight),
290         )
291         data["right_position"] = (
292             list(self.xMarkerLocations[1] / self.imgWidth),
293             list(self.yMarkerLocations[1] / self.imgHeight),
294         )
295         data["ans_matrix"] = (
296             (np.array(self.answerMatrix) > self.settings["answer_threshold"]) * 1
297         ).tolist()
69abed 298         if data["sid"] is None and data["page_no"] == 1:
e0996e 299             data["sid"] = self.get_enhanced_sid()
69abed 300         output_filename=os.path.join(self.output_path, '.'.join(self.filename.split('/')[-1].split('.')[:-1])+".png")
SP 301         cv2.imwrite(output_filename, self.img)
302         data['output_filename']=output_filename
0436f6 303         return data