안녕하세요! 은공지능 공작소의 파이찬입니다.

올 여름은 비교적 시원한 느낌이 드네요.

그래도 컴퓨터 앞에 오래 앉아있으면 더운 느낌이 들곤하죠...

모두 학교에서 직장에서 파이팅 하시길 바랍니다!!

 

오늘은 tf.where 함수에 대해 알아보겠습니다.

이전 텐서플로 버전과는 다르게, 이펙티브 텐서플로 2.0에서는

where 함수에 broadcasting 기능이 추가가 되었습니다.

그럼 기본적인 사용법부터, broadcasting 개념까지

하나하나 알아보도록 하겠습니다.

 

 

 

 

 

 

 

 

1. 사전준비

 

 

 

import tensorflow as tf

print(tf.__version__) 
output:
2.0.0-beta1

텐서플로를 불러와 줍니다.

버전을 출력했을 때, 위의 결과창처럼 결과가 떠야 합니다.

이전 버전에서도 tf.where 함수가 있지만 broadcasting은 지원하지 않습니다.

만약 텐서플로 2.0 설치와 conda 가상환경 설정법이 궁금하신 분들은

아래의 포스팅을 참고하시면 되겠습니다.

 

https://chan-lab.tistory.com/11

 

Tensorflow 2.0 Beta!! 아나콘다 가상환경 설정하기 (For window 10)

안녕하세요! 은공지능 공작소의 파이찬입니다. 지금 포스팅을 쓰는 날짜 19년 7월 17일 기준 최신버전 텐서플로 2.0 Beta 버전이 출시가 된 상태입니다. 이에 맞춰 은공지능 공작소에서는 초보자도 따라하기 쉬운..

chan-lab.tistory.com

 

 

 

c1 = tf.constant([1, 1])
c2 = tf.constant([2, 2])
c3 = tf.constant([0, 3])
cond = tf.constant([True, False])

tf.where의 기본적인 사용법을 알려드리겠습니다.

위와 같이 4가지 텐서를 만들어 둡니다.

c1 ~ c3는 숫자 들어간 상수 텐서이고, cond 함수는 True/False를 넣었습니다.

 

 

 

c1
c1.numpy()
print(c1.numpy())
def prt(input):
    print(input.numpy()) # 리턴값은 없습니다.

텐서플로 2.0에서는 따로 세션을 통해 출력을 하지 않아도 됩니다.

그래도 깔끔한 출력을 위해서 내용물을 프린트 하는 함수를

하나 만들어 두고 가겠습니다.

 

 

 

2. where 함수의 기본적인 사용법

 

 

 

tf.where(  bool type 텐서,   True일 때 출력값,   False일 때 출력값  )
tensor list:

c1:[1 1]
c2:[2 2]
c3:[0 3]
cond:[ True False]
prt(tf.where(cond, c1, c2))
prt(tf.where(tf.less(c1, c2), c1, c2))
prt(tf.where(tf.greater(c1, c3), c1, c3))
output:

[1 2]
[1 1]
[1 3]

위의 코드를 실행하면서 결과를 확인해보시길 바랍니다.

각 텐서의 원소별로 True일 때, False일 때 출력되는 것이 달라집니다.

bool type 텐서의 원소가 True 일때와 False 일때의 출력값이

달라진다는 것이 포인트입니다.

 

 

 

 

 

3. where 함수를 통해 알아보는 Broadcasting

 

 

 

브로드캐스팅(broadcasting)을 한마디로 정의하자면 바로 '확장'입니다.

where 함수에서 서로 다른 모양의 2가지 텐서가 사용되는 경우,

서로의 shape에 맞추어 행과 열을 확장해야 할 필요가 있습니다.

이럴 때 사용되는 것이 브로드캐스팅, 즉 확장입니다.

아래 코드와 그림을 통해 더 자세히 설명드리겠습니다.

 

 

 

c4, c5 = tf.constant([[1], [2], [3]]), tf.constant([[1, 2, 3, 4]])
output:

[[1]
 [2]
 [3]]
[[1 2 3 4]]

 

c4는 3 x 1 텐서, c5는 1 x 4 텐서입니다. 이 두 텐서는 서로 모양이 다릅니다.

그렇기 때문에 tf.less같은 비교연산을 위해서는 각각 행과 열을 '확장'해야 합니다.

이렇게 확장된 후에는 서로 3 x 4로 shape가 같아집니다.

이것이 바로 브로드캐스팅(broadcasting)의 개념입니다.

 

 

 

prt(tf.where(tf.less(c4, c5), tf.multiply(c4, c4), tf.multiply(c5, -1)))
output:

[[-1  1  1  1]
 [-1 -2  4  4]
 [-1 -2 -3  9]]

 

이제 브로드캐스팅된 텐서 c4, c5를 통해 tf.less 연산이 수행이 됩니다.

연산은 각각의 원소마다 적용이 됩니다.

즉, c4의 1행 1열 원소는 c5의 1행 1열의 원소와 비교가 되는 방식입니다.

연산의 결과는 위에 보시는 대로 입니다.

 

 

 

 

이제 cond 결과에 따라서 True일때 출력값, False일때 출력값이 달라지게 됩니다.

초록색은 True와 연관이 있이고, 빨간색은 False와 연관이 있습니다.

즉, True일때는 c4 원소들을 제곱을 하고

False일 때는 c5 원소에 -1을 곱해주는 코드입니다.

 

 

 

prt(tf.where(tf.less(c4, c5), tf.multiply(c4, c4), tf.multiply(c5, -1)))
output:
[[-1  1  1  1]
 [-1 -2  4  4]
 [-1 -2 -3  9]]

결과를 다시 한 번 확인하면, 위와 같습니다.

 

 

 

이렇게 하여 오늘은 텐서플로 2.0 버전의 where 함수,

그리고 where함수의 브로드캐스팅 기능까지 알아보았습니다.

도움이 되셨다면 하단의 하트 버튼 한번씩 부탁드리겠습니다.

 

감사합니다.

 

블로그 이미지

pychan

딥러닝에 관련된 시행착오, 사소하지만 중요한 것들, 가능한 모든 여정을 담았습니다.

댓글을 달아 주세요

  • hcw 2019.11.24 11:35  댓글주소  수정/삭제  댓글쓰기

    이미지 중 tf.where의 broadcasting_(3) 에서 틀린점이 있는 것 같습니다. 아래쪽 행렬에서 (2,2) 위치가 좀 이상하네요.