"""Price fenced garden regions read from a character grid.

Each cell of the grid is a plot labelled by a character; orthogonally
adjacent plots with the same label form a region. A region's price is its
area multiplied by either its perimeter (part 1) or its number of straight
sides (part 2). Running the file as a script reads ``input`` from the current
directory and prints both totals.
"""

# Unit steps on the grid: x is the real part, y (the row index, growing
# downward) is the imaginary part.
_DIRECTIONS = (1, -1, 1j, -1j)


def parse_grid(lines):
    """Return a ``{complex(x, y): label}`` map built from *lines*.

    Surrounding whitespace (including the trailing newline) is stripped from
    each line before its characters are enumerated.
    """
    return {
        complex(x, y): label
        for y, line in enumerate(lines)
        for x, label in enumerate(line.strip())
    }


def load_grid(path="input"):
    """Read and parse the grid stored in *path*, closing the file afterwards."""
    with open(path, "r", encoding="utf-8") as fh:
        return parse_grid(fh)


def _flood_region(grid, start, seen):
    """Flood-fill the region containing *start*.

    Adds every cell of the region to *seen* and returns ``(area, edges)``,
    where *edges* holds one ``(direction, cell)`` pair per unit fence
    segment: a region *cell* whose neighbour in *direction* is outside the
    region (different label or off the grid).
    """
    label = grid[start]
    stack = [start]
    seen.add(start)
    area = 0
    edges = set()
    while stack:
        cell = stack.pop()
        area += 1
        for step in _DIRECTIONS:
            neighbour = cell + step
            if grid.get(neighbour) == label:
                # Mark on push so a cell is never queued twice.
                if neighbour not in seen:
                    seen.add(neighbour)
                    stack.append(neighbour)
            else:
                edges.add((step, cell))
    return area, edges


def _count_sides(edges):
    """Count the straight sides formed by the fence segments in *edges*.

    A side is a maximal run of segments facing the same direction on cells
    that are consecutive along ``direction * 1j`` (the perpendicular). Each
    run has exactly one segment whose successor along that perpendicular is
    not in *edges*, so counting those segments counts the runs.
    """
    return sum((step, cell + step * 1j) not in edges for step, cell in edges)


def run(p2=False, *, grid=None):
    """Return the total fencing price over all regions of *grid*.

    :param p2: if false, price = area * perimeter; if true,
        price = area * number of straight sides.
    :param grid: ``{complex(x, y): label}`` map as produced by
        ``parse_grid``; when omitted, the grid is loaded from the file
        ``input`` in the current directory.
    :return: the summed price (0 for an empty grid).
    """
    if grid is None:
        grid = load_grid()
    seen = set()
    total = 0
    for start in grid:
        if start in seen:
            continue
        area, edges = _flood_region(grid, start, seen)
        fences = _count_sides(edges) if p2 else len(edges)
        total += area * fences
    return total


if __name__ == "__main__":
    d = load_grid()
    print(run(grid=d))
    print(run(True, grid=d))