"""Search a character grid for "XMAS" words and X-shaped "MAS" crosses.

When run as a script, the grid is read from the file ``input`` in the current
working directory and two numbers are printed: the result of ``run()``
followed by the result of ``run2()``.
"""

# Character grid indexed as d[row][col].  It is empty on import; the script
# entry point at the bottom replaces it with the contents of the "input" file.
d = []

# (row step, column step) for each of the eight neighbouring directions.
_DIRECTIONS = (
    (-1, -1), (-1, 0), (-1, 1),
    (0, -1), (0, 1),
    (1, -1), (1, 0), (1, 1),
)

# Accepted readings of the two ends of one diagonal of a cross.
_DIAGONAL_ENDS = frozenset({"MS", "SM"})


def _load_grid(path):
    """Read *path* and return one list of characters per (stripped) line."""
    # The context manager closes the file even if reading fails.
    with open(path, "r") as handle:
        return [list(line.strip()) for line in handle]


def _in_bounds(row, col):
    """Return True when ``d[row][col]`` is an existing cell of the grid."""
    # The column is checked against that row's own length, so a short or
    # blank row is reported as out of bounds instead of raising IndexError.
    return 0 <= row < len(d) and 0 <= col < len(d[row])


def xmas(x: int, y: int) -> int:
    """Count the directions in which "M", "A", "S" follow the cell (x, y).

    The cell at row ``x``, column ``y`` itself is not inspected; callers are
    expected to pass the position of an "X".

    Returns:
        The number of the eight directions (0-8) whose next three cells are
        inside the grid and read "M", "A", "S" in that order.
    """
    total = 0
    for d_row, d_col in _DIRECTIONS:
        row, col = x, y
        for expected in "MAS":
            row, col = row + d_row, col + d_col
            if not _in_bounds(row, col) or d[row][col] != expected:
                break
        else:
            # The inner loop finished without a break: all three letters matched.
            total += 1
    return total


def mas(x: int, y: int) -> int:
    """Return 1 if both diagonals through (x, y) have "M" and "S" ends, else 0.

    The cell at row ``x``, column ``y`` itself is not inspected; callers are
    expected to pass the position of an "A".  If any of the four diagonal
    neighbours lies outside the grid the result is 0.
    """
    corners = (
        (x - 1, y - 1), (x + 1, y + 1),
        (x - 1, y + 1), (x + 1, y - 1),
    )
    if not all(_in_bounds(row, col) for row, col in corners):
        return 0

    falling = d[x - 1][y - 1] + d[x + 1][y + 1]
    rising = d[x - 1][y + 1] + d[x + 1][y - 1]
    return int(falling in _DIAGONAL_ENDS and rising in _DIAGONAL_ENDS)


def run() -> int:
    """Return the sum of ``xmas(i, j)`` over every "X" cell of the grid."""
    return sum(
        xmas(i, j)
        for i, line in enumerate(d)
        for j, char in enumerate(line)
        if char == "X"
    )


def run2() -> int:
    """Return the sum of ``mas(i, j)`` over every "A" cell of the grid."""
    return sum(
        mas(i, j)
        for i, line in enumerate(d)
        for j, char in enumerate(line)
        if char == "A"
    )


if __name__ == "__main__":
    d = _load_grid("input")
    print(run())
    print(run2())