Coverage for src/pytest_patterns/plugin.py: 71%

168 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-06-14 10:28 +0200

1import enum 

2import re 

3from typing import Iterable, List, Set, Tuple 

4 

5import pytest 

6 

7 

@pytest.fixture
def patterns():
    """Provide a new, empty PatternsLib for the requesting test.

    Patterns are created lazily by attribute access on the returned object.
    """
    return PatternsLib()

11 

12 

def _is_pattern(obj):
    """Return True if ``obj`` looks like a pytest-patterns Pattern.

    The class name alone is not enough: compiled regular expressions
    (``re.compile(...)``) are also instances of a class named ``Pattern``
    but have no ``_audit`` method, which this hook needs to call.
    """
    return obj.__class__.__name__ == "Pattern" and callable(
        getattr(obj, "_audit", None)
    )


def pytest_assertrepr_compare(op, left, right):
    """Explain a failed ``pattern == string`` (or ``string == pattern``).

    Returns the audit report lines for pytest to display, or None to let
    pytest fall back to its default explanation (other operators, or
    operands that are not a Pattern/string pair).
    """
    if op != "==":
        return None
    if _is_pattern(left) and isinstance(right, str):
        return list(left._audit(right).report())
    if _is_pattern(right) and isinstance(left, str):
        return list(right._audit(left).report())
    return None

20 

21 

class Status(enum.Enum):
    """Verdict for a single content line after an audit.

    The numeric values encode precedence: ``Line.mark`` only replaces a
    line's status with one whose value is strictly higher.
    """

    # Default: no rule claimed the line.
    UNEXPECTED = 1
    # Tolerated: matched an ``optional`` rule, or a blank line inside a
    # ``continuous`` block.
    OPTIONAL = 2
    # Required by an ``in_order`` or ``continuous`` rule and found.
    EXPECTED = 3
    # Matched a ``refused`` rule, or interrupted a ``continuous`` block.
    REFUSED = 4

    @property
    def symbol(self):
        """Emoji that represents this status in audit reports."""
        return STATUS_SYMBOLS[self]


# Emoji per status, used for the report legend and for each report row.
STATUS_SYMBOLS = dict(
    (
        (Status.UNEXPECTED, "🟡"),
        (Status.EXPECTED, "🟢"),
        (Status.OPTIONAL, "⚪️"),
        (Status.REFUSED, "🔴"),
    )
)

39 

# Pattern line that matches an empty content line.
EMPTY_LINE_PATTERN = "<empty-line>"


def match(pattern, line):
    """Match one content line against one pattern line.

    Every pattern character matches literally except ``...``, which is a
    non-greedy wildcard. Tabs on both sides are replaced by eight spaces
    before comparing, and the whole line must match.

    Returns True when ``pattern`` is EMPTY_LINE_PATTERN and ``line`` is
    empty; otherwise the ``re.Match`` object, or None if there is no match.
    """
    if pattern == EMPTY_LINE_PATTERN and not line:
        return True

    tab_replacement = " " * 8
    literal = re.escape(pattern.replace("\t", tab_replacement))
    # re.escape turns "..." into r"\.\.\." -- turn that back into a wildcard.
    regex = literal.replace(r"\.\.\.", ".*?")
    anchored = re.compile("^" + regex + "$")
    return anchored.match(line.replace("\t", tab_replacement))

53 

54 

class Line:
    """A single line of audited content together with its current verdict."""

    # Class-level defaults; ``mark`` overrides them on the instance.
    status: Status = Status.UNEXPECTED
    status_cause: str = ""

    def __init__(self, data: str):
        # The text of this content line.
        self.data = data

    def matches(self, expectation: str):
        """Return True if this line satisfies the pattern line ``expectation``."""
        return True if match(expectation, self.data) else False

    def mark(self, status: Status, cause: str):
        """Record ``status`` caused by rule ``cause``, if it outranks the
        current status.

        Equal or lower statuses are ignored, so the first rule that raises
        the line to a given level keeps its name as the cause.
        """
        if status.value > self.status.value:
            self.status = status
            self.status_cause = cause

71 

72 

class Audit:
    """Outcome of checking one string against a pattern's rules.

    The rule methods (``in_order``, ``optional``, ``refused``,
    ``continuous``) mark the content lines they match and collect pattern
    lines that could not be satisfied. ``is_ok`` gives the verdict and
    ``report`` yields a human-readable explanation.
    """

    content: List[Line]
    # (rule name, pattern line) for every required line that was not found.
    unmatched_expectations: List[Tuple[str, str]]
    # (rule name, pattern line) for every refused line that was found.
    matched_refused: Set[Tuple[str, str]]

    def __init__(self, content: str):
        self.unmatched_expectations = []
        self.matched_refused = set()
        self.content = [Line(text) for text in content.splitlines()]

    def cursor(self):
        """Return a fresh iterator over all content lines."""
        return iter(self.content)

    def in_order(self, name: str, expected_lines: List[str]):
        """Expect all lines exist and come in order, but they
        may be interleaved with other lines.

        Each expected line is searched for after the most recent successful
        match. A line that cannot be found is recorded as unmatched, and the
        search for the next expected line resumes from that same position.
        This way a single missing line does not also hide the lines after it.
        """
        search_from = 0
        for expected_line in expected_lines:
            for index in range(search_from, len(self.content)):
                line = self.content[index]
                if line.matches(expected_line):
                    line.mark(Status.EXPECTED, name)
                    search_from = index + 1
                    break
            else:
                self.unmatched_expectations.append((name, expected_line))

    def optional(self, name: str, tolerated_lines: List[str]):
        """Those lines may exist and then they may appear anywhere
        a number of times, or they may not exist.
        """
        for tolerated_line in tolerated_lines:
            for line in self.content:
                if line.matches(tolerated_line):
                    line.mark(Status.OPTIONAL, name)

    def refused(self, name: str, refused_lines: List[str]):
        """Mark every content line matching a refused line as REFUSED and
        remember which refused pattern lines matched."""
        for refused_line in refused_lines:
            for line in self.content:
                if line.matches(refused_line):
                    line.mark(Status.REFUSED, name)
                    self.matched_refused.add((name, refused_line))

    def continuous(self, name: str, continuous_lines: List[str]):
        """Expect the lines to appear once, as one contiguous block.

        Empty content lines are tolerated between the block's lines (but not
        before its first line), because empty lines are filtered out of the
        pattern to keep it readable. An explicit EMPTY_LINE_PATTERN still
        consumes one empty line.

        Every position where the first line matches is tried as a start, so
        an early partial match does not hide a complete match further down.
        If no complete match exists, the first partial attempt is reported.
        An empty pattern imposes no expectation.
        """
        if not continuous_lines:
            return
        first_line = continuous_lines[0]
        fallback = None
        for start, line in enumerate(self.content):
            if not line.matches(first_line):
                continue
            attempt = self._scan_continuous(start, continuous_lines)
            if not attempt[1]:
                outcome = attempt
                break
            if fallback is None:
                fallback = attempt
        else:
            # No complete match: report the first partial attempt, or every
            # pattern line as missing when the first line never matched.
            outcome = fallback if fallback is not None else ([], continuous_lines)
        marks, missing = outcome
        for line, status in marks:
            line.mark(status, name)
        self.unmatched_expectations.extend(
            (name, pattern_line) for pattern_line in missing
        )

    def _scan_continuous(self, start: int, continuous_lines: List[str]):
        """Try to match ``continuous_lines`` as a block starting at ``start``.

        Expects ``self.content[start]`` to match the first pattern line.
        Nothing is marked here, so discarded attempts leave no trace.
        Returns ``(marks, missing)``: ``marks`` is a list of
        ``(Line, Status)`` pairs to apply, and ``missing`` holds the pattern
        lines that were not matched (empty on success).
        """
        marks = []
        position = 0
        for index in range(start, len(self.content)):
            line = self.content[index]
            expected = continuous_lines[position]
            if position and not line.data and expected != EMPTY_LINE_PATTERN:
                # Blank line inside the block: tolerated.
                marks.append((line, Status.OPTIONAL))
                continue
            if line.matches(expected):
                marks.append((line, Status.EXPECTED))
                position += 1
                if position == len(continuous_lines):
                    return marks, []
                continue
            if position:
                # Once the block has started, every line must match.
                marks.append((line, Status.REFUSED))
            break
        return marks, continuous_lines[position:]

    def report(self):
        """Yield the lines of a human-readable explanation of this audit."""
        yield "String did not meet the expectations."
        yield ""
        yield " | ".join(
            [
                Status.EXPECTED.symbol + "=EXPECTED",
                Status.OPTIONAL.symbol + "=OPTIONAL",
                Status.UNEXPECTED.symbol + "=UNEXPECTED",
                Status.REFUSED.symbol + "=REFUSED/UNMATCHED",
            ]
        )
        yield ""
        yield "Here is the string that was tested: "
        yield ""
        for line in self.content:
            yield format_line_report(
                line.status.symbol, line.status_cause, line.data
            )
        if self.unmatched_expectations:
            yield ""
            yield "These are the unmatched expected lines: "
            yield ""
            for name, line in self.unmatched_expectations:
                yield format_line_report(Status.REFUSED.symbol, name, line)
        if self.matched_refused:
            yield ""
            yield "These are the matched refused lines: "
            yield ""
            # Sorted: set iteration order depends on string hashing and
            # would otherwise vary between test runs.
            for name, line in sorted(self.matched_refused):
                yield format_line_report(Status.REFUSED.symbol, name, line)

    def is_ok(self):
        """Return True if every expectation was met and every content line
        ended up EXPECTED or OPTIONAL."""
        if self.unmatched_expectations:
            return False
        return all(
            line.status in (Status.EXPECTED, Status.OPTIONAL)
            for line in self.content
        )

195 

196 

def format_line_report(symbol, cause, line):
    """Render one report row.

    The row is the status symbol, then the rule name padded or truncated to
    exactly 15 characters, then `` | `` and the line text.
    """
    return "{} {:<15.15} | {}".format(symbol, cause, line)

199 

200 

def pattern_lines(lines: str) -> List[str]:
    """Split a multi-line pattern string into its pattern lines.

    Only completely empty lines are dropped. Every other line, including a
    whitespace-only line, is kept verbatim; leading whitespace is NOT
    stripped.
    """
    return [line for line in lines.splitlines() if line]

204 

205 

class Pattern:
    """A named set of line-matching rules, checked by comparing with ``==``.

    ``pattern == some_string`` audits the string against all rules of this
    pattern, including merged ones, and is True only if the audit passes.
    """

    def __init__(self, library, name):
        self.name = name
        self.library = library
        # Rules as (audit method name, rule name, pattern lines) tuples.
        self.ops = []
        # Names of patterns (looked up on ``library``) whose rules run before
        # this pattern's own. A list keeps merge order stable, so reports do
        # not depend on string hash order.
        self.inherited = []

    def __repr__(self):
        return "<Pattern {!r}>".format(self.name)

    # Modifiers (Verbs)

    def merge(self, *base_patterns):
        """Merge the rules from those patterns (recursively) into this pattern.

        ``base_patterns`` are names of patterns in the same library. Merging
        a name that was already merged has no further effect.
        """
        for base_pattern in base_patterns:
            if base_pattern not in self.inherited:
                self.inherited.append(base_pattern)

    def normalize(self, mode: str):
        # NOTE(review): placeholder -- currently has no effect.
        pass

    # Matches (Adjectives)

    def continuous(self, lines: str):
        """These lines must appear once and they must be continuous."""
        self.ops.append(("continuous", self.name, pattern_lines(lines)))

    def in_order(self, lines: str):
        """These lines must appear once and they must be in order."""
        self.ops.append(("in_order", self.name, pattern_lines(lines)))

    def optional(self, lines: str):
        """These lines are optional."""
        self.ops.append(("optional", self.name, pattern_lines(lines)))

    def refused(self, lines: str):
        """If those lines appear they are refused."""
        self.ops.append(("refused", self.name, pattern_lines(lines)))

    # Internal API

    def flat_ops(self, _seen=None):
        """Yield all rules: merged patterns' rules first, then our own.

        Each pattern contributes its rules at most once, even when several
        merges reach it (diamond) or merges form a cycle. ``_seen`` is
        internal bookkeeping for the recursion.
        """
        if _seen is None:
            _seen = set()
        # Patterns are unhashable (they define __eq__), so track identities.
        if id(self) in _seen:
            return
        _seen.add(id(self))
        for inherited_pattern in self.inherited:
            yield from getattr(self.library, inherited_pattern).flat_ops(_seen)
        yield from self.ops

    def _audit(self, content):
        """Run every rule against ``content`` and return the Audit."""
        audit = Audit(content)
        for op, *args in self.flat_ops():
            getattr(audit, op)(*args)
        return audit

    def __eq__(self, other):
        """Return True if the string ``other`` satisfies all rules.

        Non-strings are not supported: NotImplemented lets Python fall back
        to its default comparison instead of failing via ``assert``, which
        ``python -O`` would strip anyway.
        """
        if not isinstance(other, str):
            return NotImplemented
        return self._audit(other).is_ok()

257 

258 

class PatternsLib:
    """Namespace that creates a Pattern the first time a name is accessed.

    The Pattern is stored in the instance ``__dict__``, so later lookups of
    the same name find it directly and return the same object.
    """

    def __getattr__(self, name):
        # Only called when normal lookup fails. Dunder names are protocol
        # probes (e.g. ``hasattr(obj, "__wrapped__")``), not pattern names,
        # so they must raise AttributeError instead of creating a Pattern.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(
                    type(self).__name__, name
                )
            )
        pattern = Pattern(self, name)
        self.__dict__[name] = pattern
        return pattern