230508 / BSA10. pyspark에서 스팸 메일 분류

BSA07_Pyspark-SMS-Spam-Analysis.ipynb

 

패키지 호출

from pyspark.sql import SparkSession
from pyspark.ml.feature import Tokenizer, RegexTokenizer
from pyspark.ml.feature import CountVectorizer
from pyspark.sql.functions import col, udf
from pyspark.sql.types import IntegerType

from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.feature import IDF
from pyspark.sql.functions import length

from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vector
from pyspark.ml import Pipeline

from pyspark.ml.classification import NaiveBayes, LogisticRegression, RandomForestClassifier, GBTClassifier, LinearSVC
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

 

토큰 개수 계산용 UDF 정의

# UDF returning the number of elements in its input as an integer
# (presumably meant for a token-array column -- confirm against its callers).
# Spark hands SQL NULL to a Python UDF as None, and len(None) raises
# TypeError inside the worker, so NULL inputs yield a NULL result instead of
# failing the whole job.
countTokenizer = udf(
    lambda w: len(w) if w is not None else None,
    IntegerType(),
)

 

스파크 세션 시작 및 데이터 불러오기

# Start (or reuse) the Spark session used by the rest of the notebook.
spark = (
    SparkSession.builder
    .appName("nlp_nb")
    .getOrCreate()
)

# Tab-separated input read without a header option, so the columns arrive
# under Spark's default names (_c0, _c1).
df = spark.read.csv("SMSSpamCollection", sep="\t", inferSchema=True)

# Rename the default columns to meaningful names.
df = df.withColumnRenamed("_c0", "label")
df = df.withColumnRenamed("_c1", "messages")
df.show()

# Add the character length of each message as a numeric column.
df = df.withColumn("length", length(col("messages")))
df.show()

# Per-label mean of the numeric columns.
df.groupby("label").mean().show()

# Text-processing stages; each stage reads the previous stage's output column.
tokenizer = Tokenizer(
    inputCol="messages",
    outputCol="tokened",
)
stop_word_remover = StopWordsRemover(
    inputCol="tokened",
    outputCol="stoped",
)
count_vec = CountVectorizer(
    inputCol="stoped",
    outputCol="c_vec",
)
idf = IDF(
    inputCol="c_vec",
    outputCol="tf_idf",
)

# Index the string label column into a numeric column.
ham_spam_to_num = StringIndexer(
    inputCol="label",
    outputCol="label_01",
)

# Combine the TF-IDF vector and the message length into one feature vector.
cleaned = VectorAssembler(
    inputCols=["tf_idf", "length"],
    outputCol="features",
)

 

데이터 전처리

# Stages run in list order: label indexing first, then the text stages that
# feed one another, and finally the feature assembler.
_preprocess_stages = [
    ham_spam_to_num,
    tokenizer,
    stop_word_remover,
    count_vec,
    idf,
    cleaned,
]
파이프라인 = Pipeline(stages=_preprocess_stages)

# Fit the pipeline on the full DataFrame, then apply it to the same data.
전처리기 = 파이프라인.fit(df)
전처리_df = 전처리기.transform(df)
전처리_df.show()

 

모델 적용 및 평가

# Keep only the indexed label and the assembled features; rename the label
# column to "label", the default label column name of the estimator and
# evaluator below.
최종_df = 전처리_df.select(['label_01','features']).withColumnRenamed("label_01","label")

# 75/25 train/test split with a fixed seed.
(train_df, test_df) = 최종_df.randomSplit([.75,.25], seed= 316)
train_df.show()

# 1. Naive Bayes
nb = NaiveBayes()
적합모형 = nb.fit(train_df)
적합결과 = 적합모형.transform(test_df)

# MulticlassClassificationEvaluator defaults to metricName="f1"; request
# accuracy explicitly so the printed value matches its "Accuracy" label.
# NOTE(review): `eval` shadows the Python builtin; the name is kept because
# later cells may reference it -- consider renaming it to `evaluator`.
eval = MulticlassClassificationEvaluator(metricName="accuracy")
acc = eval.evaluate(적합결과)
print(f"Accuracy:{acc*100}")