Computer Science/Algorithm

[Python 3] BOJ 1285 동전 뒤집기

무니화니 2025. 2. 17. 11:54

 

비트마스킹과 브루트포스, (그리디)를 이용한 문제였다.

먼저, 행을 기준으로 뒤집을 수 있는 모든 경우에 대해서 뒤집기를 실시한다.

n이 20인 경우 2의 20제곱 (약 100만번) 정도의 뒤집기를 한다고 생각하면 된다.

모든 경우를 다 뒤집기 때문에 브루트 포스이다.

또한, 여기서 비트마스킹을 사용한다. 

0부터 2^(n-1)-1까지의 숫자를 이용해서, 각 비트에 1이 있으면 뒤집기를, 없으면 뒤집기를 실시하지 않는다.

 

for bit in range(1<<n):
    
        if bit&(1<<i):

예컨데, n이 3개짜리인 테이블을 예시로 보자.

H H T
T H T
H T H

[표 1: 현재 테이블]

 

다음과 같은 3행에 뒤집기를 하고 싶다. 그렇다면 ,

비트마스킹에서 '100'의 경우일 것이다.

즉, 1행과 2행에서는 뒤집지 않고, 3번째 비트인 1에서만 뒤집는다.

H H T
T H T
T H T

[표 2: '100' 비트]

 

두 행 이상을 뒤집는 것도 가능하다.

예를 들면, 1행과 2행을 뒤집고, 3번째 행은 그대로 두고 싶다고 하면,

비트마스킹에서 '001'와 같은 경우일 것이다.

 T T H
H T H
H T H

[표 3: '011' 비트]

 

여기서 행에 대한 뒤집기를 모두 마쳤다.

이제 열에 대해서 고민해보자. 각 열의 뒷면과 앞면의 개수를 세어보자.

만약 뒷면이 더 많으면, 더 이상 뒤집을 필요가 없다.

하지만, 앞면이 더 많으면, 해당 열을 기준으로 뒤집어서 동전의 뒷면이 더 많아지게 하면 된다.

tempsum=0
    for j in range(n):
        count=0
        for i in range(n):
            if temp[i][j]=='H':
                count+=1
        tempsum+=min(count,n-count)
    answer=min(answer,tempsum)

 

 

 

깊은 복사를 통해 data의 내용이 매번 새로운 인스턴스로 정의되게 하였다.

temp=[data[i][:] for i in range(n)]

 

 

마지막으로, 3단 for문을 사용해기도 하고, 시간 복잡도를 따져보았을 때 O(2^n * n^2)이다.

굉장히 괴랄한 복잡도이다. 실제로는 n이 조금만 더 커져도, 채점이 평생 걸릴거고, TLE는 거의 확실하다.

 

하지만, 시간 제한이 6초이다.

또한 pypy3로 제출할 것이기 때문에 , 시간 제한에 6*3+2초, 총 20초를 부여해준다.

(https://help.acmicpc.net/language/info 해당 내용은 여기서 참고 가능하다.)

그리고 n 또한 제한이 20이기에 가능한 풀이이다.

 

n의 크기가 작으면 비트마스킹을 생각해보아야겠다.

 

import copy
import sys
input=sys.stdin.readline
n=int(input())
data=list(list(input().strip()) for _ in range(n))
answer=1e10
for bit in range(1<<n):
    temp=[data[i][:] for i in range(n)]
    for i in range(n):
        if bit&(1<<i):
            for j in range(n):
                if temp[i][j]=='H':
                    temp[i][j]='T'
                else:
                    temp[i][j]='H'
    tempsum=0
    for j in range(n):
        count=0
        for i in range(n):
            if temp[i][j]=='H':
                count+=1
        tempsum+=min(count,n-count)
    answer=min(answer,tempsum)
print(answer)