Development of the ocr part of AOI
Samo Penic
2018-11-17 e0996e35862256114826a6314dc649972972a60c
commit | author | age
e555c0 1 from pyzbar.pyzbar import decode
02e0f7 2 from sid_process import getSID
e555c0 3 import cv2
SP 4 import numpy as np
5 import math
6
7
511c2e 8 class Paper:
762a5e 9     def __init__(self, filename=None, sid_classifier=None, settings=None):
511c2e 10         self.filename = filename
SP 11         self.invalid = None
12         self.QRData = None
e0996e 13         self.settings = {"answer_threshold": 0.25} if settings is None else settings
511c2e 14         self.errors = []
SP 15         self.warnings = []
e0996e 16         self.sid = None
762a5e 17         self.sid_classifier = sid_classifier
511c2e 18         if filename is not None:
SP 19             self.loadImage(filename)
20             self.runOcr()
e555c0 21
511c2e 22     def loadImage(self, filename, rgbchannel=0):
SP 23         self.img = cv2.imread(filename, rgbchannel)
24         if self.img is None:
25             self.errors.append("File could not be loaded!")
26             self.invalid = True
27             return
28         self.imgHeight, self.imgWidth = self.img.shape[0:2]
e555c0 29
511c2e 30     def saveImage(self, filename="debug_image.png"):
SP 31         cv2.imwrite(filename, self.img)
e555c0 32
511c2e 33     def runOcr(self):
SP 34         if self.invalid == True:
35             return
36         self.decodeQRandRotate()
37         self.imgTreshold()
38         skewAngle = 0
39         #         try:
40         #             skewAngle=self.getSkewAngle()
41         #         except:
42         #             self.errors.append("Could not determine skew angle!")
43         #         self.rotateAngle(skewAngle)
e555c0 44
511c2e 45         self.generateAnswerMatrix()
e555c0 46
511c2e 47         self.saveImage()
e555c0 48
511c2e 49     def decodeQRandRotate(self):
SP 50         if self.invalid == True:
51             return
52         blur = cv2.blur(self.img, (3, 3))
53         d = decode(blur)
54         self.img = blur
55         if len(d) == 0:
56             self.errors.append("QR code could not be found!")
57             self.data = None
58             self.invalid = True
59             return
e0996e 60         if(len(d)>1): #if there are multiple codes, get first ean or qr code available.
SP 61             for dd in d:
62                 if(dd.type=="EAN13" or dd.type=="QR"):
63                     d[0]=dd
64                     break
511c2e 65         self.QRDecode = d
SP 66         self.QRData = d[0].data
67         xpos = d[0].rect.left
68         ypos = d[0].rect.top
69         # check if image is rotated wrongly
82ec6d 70         if xpos > self.imgHeight / 2.0 and ypos > self.imgWidth / 2.0:
511c2e 71             self.rotateAngle(180)
e555c0 72
511c2e 73     def rotateAngle(self, angle=0):
e0996e 74         # rot_mat = cv2.getRotationMatrix2D(
82ec6d 75         #    (self.imgHeight / 2, self.imgWidth / 2), angle, 1.0
e0996e 76         # )
511c2e 77         rot_mat = cv2.getRotationMatrix2D(
e0996e 78             (self.imgWidth / 2, self.imgHeight / 2), angle, 1.0
511c2e 79         )
SP 80         result = cv2.warpAffine(
81             self.img,
82             rot_mat,
82ec6d 83             (self.imgWidth, self.imgHeight),
511c2e 84             flags=cv2.INTER_CUBIC,
SP 85             borderMode=cv2.BORDER_CONSTANT,
86             borderValue=(255, 255, 255),
87         )
e555c0 88
511c2e 89         self.img = result
SP 90         self.imgHeight, self.imgWidth = self.img.shape[0:2]
e555c0 91
511c2e 92         # todo, make better tresholding
e555c0 93
511c2e 94     def imgTreshold(self):
SP 95         (self.thresh, self.bwimg) = cv2.threshold(
96             self.img, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU
97         )
e555c0 98
511c2e 99     def getSkewAngle(self):
SP 100         neg = 255 - self.bwimg  # get negative image
101         cv2.imwrite("debug_1.png", neg)
e555c0 102
511c2e 103         angle_counter = 0  # number of angles
SP 104         angle = 0.0  # collects sum of angles
105         cimg = cv2.cvtColor(self.img, cv2.COLOR_GRAY2BGR)
e555c0 106
511c2e 107         # get all the Hough lines
SP 108         for line in cv2.HoughLinesP(neg, 1, np.pi / 180, 325):
109             x1, y1, x2, y2 = line[0]
110             cv2.line(cimg, (x1, y1), (x2, y2), (0, 0, 255), 2)
111             # calculate the angle (in radians)
112             this_angle = np.arctan2(y2 - y1, x2 - x1)
113             if this_angle and abs(this_angle) <= 10:
114                 # filtered zero degree and outliers
115                 angle += this_angle
116                 angle_counter += 1
e555c0 117
511c2e 118                 # the skew is calculated of the mean of the total angles, #try block helps with division by zero.
SP 119         try:
120             skew = np.rad2deg(
121                 angle / angle_counter
122             )  # the 1.2 factor is just experimental....
123         except:
124             skew = 0
e555c0 125
511c2e 126         cv2.imwrite("debug_2.png", cimg)
SP 127         return skew
e555c0 128
511c2e 129     def locateUpMarkers(self, threshold=0.85, height=200):
SP 130         template = cv2.imread("template.png", 0)
131         w, h = template.shape[::-1]
132         crop_img = self.img[0:height, :]
133         res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
134         loc = np.where(res >= threshold)
135         cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
136         # remove false matching of the squares in qr code
137         loc_filtered_x = []
138         loc_filtered_y = []
139         if len(loc[0]) == 0:
140             min_y = -1
141         else:
142             min_y = np.min(loc[0])
143             for pt in zip(*loc[::-1]):
144                 if pt[1] < min_y + 20:
145                     loc_filtered_y.append(pt[1])
146                     loc_filtered_x.append(pt[0])
147                     # order by x coordinate
148             loc_filtered_x, loc_filtered_y = zip(
149                 *sorted(zip(loc_filtered_x, loc_filtered_y))
150             )
02e0f7 151             # loc=[loc_filtered_y,loc_filtered_x]
SP 152             # remove duplicates
511c2e 153             a = np.diff(loc_filtered_x) > 40
SP 154             a = np.append(a, True)
155             loc_filtered_x = np.array(loc_filtered_x)
156             loc_filtered_y = np.array(loc_filtered_y)
157             loc = [loc_filtered_y[a], loc_filtered_x[a]]
158             for pt in zip(*loc[::-1]):
159                 cv2.rectangle(cimg, pt, (pt[0] + w, pt[1] + h), (0, 255, 255), 2)
e555c0 160
511c2e 161         cv2.imwrite("debug_3.png", cimg)
e555c0 162
511c2e 163         self.xMarkerLocations = loc
SP 164         return loc
e555c0 165
511c2e 166     def locateRightMarkers(self, threshold=0.85, width=200):
SP 167         template = cv2.imread("template.png", 0)
168         w, h = template.shape[::-1]
169         crop_img = self.img[:, -width:]
170         res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
171         loc = np.where(res >= threshold)
172         cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
173         # remove false matching of the squares in qr code
174         loc_filtered_x = []
175         loc_filtered_y = []
176         if len(loc[1]) == 0:
177             min_x = -1
178         else:
179             max_x = np.max(loc[1])
180             for pt in zip(*loc[::-1]):
181                 if pt[1] > max_x - 20:
182                     loc_filtered_y.append(pt[1])
183                     loc_filtered_x.append(pt[0])
184                     # order by y coordinate
185             loc_filtered_y, loc_filtered_x = zip(
186                 *sorted(zip(loc_filtered_y, loc_filtered_x))
187             )
188             # loc=[loc_filtered_y,loc_filtered_x]
189             # remove duplicates
190             a = np.diff(loc_filtered_y) > 40
191             a = np.append(a, True)
192             loc_filtered_x = np.array(loc_filtered_x)
193             loc_filtered_y = np.array(loc_filtered_y)
194             loc = [loc_filtered_y[a], loc_filtered_x[a]]
195             for pt in zip(*loc[::-1]):
196                 cv2.rectangle(cimg, pt, (pt[0] + w, pt[1] + h), (0, 255, 255), 2)
e555c0 197
511c2e 198         cv2.imwrite("debug_4.png", cimg)
e555c0 199
511c2e 200         self.yMarkerLocations = [loc[0], loc[1] + self.imgWidth - width]
SP 201         return self.yMarkerLocations
e555c0 202
511c2e 203     def generateAnswerMatrix(self):
SP 204         self.locateUpMarkers()
205         self.locateRightMarkers()
e555c0 206
511c2e 207         roixoff = 10
SP 208         roiyoff = 5
209         roiwidth = 50
210         roiheight = roiwidth
211         totpx = roiwidth * roiheight
e555c0 212
511c2e 213         self.answerMatrix = []
SP 214         for y in self.yMarkerLocations[0]:
215             oneline = []
216             for x in self.xMarkerLocations[1]:
217                 roi = self.bwimg[
218                     y - roiyoff : y + int(roiheight - roiyoff),
219                     x - roixoff : x + int(roiwidth - roixoff),
220                 ]
221                 # cv2.imwrite('ans_x'+str(x)+'_y_'+str(y)+'.png',roi)
222                 black = totpx - cv2.countNonZero(roi)
223                 oneline.append(black / totpx)
224             self.answerMatrix.append(oneline)
9efc18 225
SP 226     def get_enhanced_sid(self):
02e0f7 227         if self.sid_classifier is None:
SP 228             return "x"
762a5e 229         if self.settings is not None:
e0996e 230             sid_mask = self.settings.get("sid_mask", None)
SP 231         es, err, warn = getSID(
02e0f7 232             self.img[
d5c694 233                 int(0.04 * self.imgHeight) : int(0.095 * self.imgHeight),
02e0f7 234                 int(0.7 * self.imgWidth) : int(0.99 * self.imgWidth),
SP 235             ],
236             self.sid_classifier,
e0996e 237             sid_mask,
02e0f7 238         )
5cb7c1 239         [self.errors.append(e) for e in err]
SP 240         [self.warnings.append(w) for w in warn]
02e0f7 241         return es
0436f6 242
SP 243     def get_code_data(self):
cf921b 244         if self.QRData is None:
SP 245             self.errors.append("Could not read QR or EAN code! Not an exam?")
e0996e 246             retval = {
SP 247                 "exam_id": None,
248                 "page_no": None,
249                 "paper_id": None,
250                 "faculty_id": None,
251                 "sid": None,
0436f6 252             }
e0996e 253             return retval
SP 254         qrdata = bytes.decode(self.QRData, "utf8")
255         if self.QRDecode[0].type == "EAN13":
256             return {
257                 "exam_id": int(qrdata[0:7]),
258                 "page_no": int(qrdata[7]),
259                 "paper_id": int(qrdata[-5:-1]),
260                 "faculty_id": None,
261                 "sid": None,
262             }
263         else:
264             data = qrdata.split(",")
265             retval = {
266                 "exam_id": int(data[1]),
267                 "page_no": int(data[3]),
268                 "paper_id": int(data[2]),
269                 "faculty_id": int(data[0]),
270             }
271             if len(data) > 4:
272                 retval["sid"] = data[4]
0436f6 273
SP 274             return retval
275
276     def get_paper_ocr_data(self):
e0996e 277         data = self.get_code_data()
SP 278         data["qr"] = self.QRData
279         data["errors"] = self.errors
280         data["warnings"] = self.warnings
281         data["up_position"] = (
282             list(self.xMarkerLocations[1] / self.imgWidth),
283             list(self.yMarkerLocations[1] / self.imgHeight),
284         )
285         data["right_position"] = (
286             list(self.xMarkerLocations[1] / self.imgWidth),
287             list(self.yMarkerLocations[1] / self.imgHeight),
288         )
289         data["ans_matrix"] = (
290             (np.array(self.answerMatrix) > self.settings["answer_threshold"]) * 1
291         ).tolist()
292         if data["sid"] is None and data["page_no"] == 0:
293             data["sid"] = self.get_enhanced_sid()
0436f6 294         return data