notebook
pyLDAvis
.transform
spark LDA example data set
# Code to train LDA model using spark ml
from pyspark.ml.clustering import LDA
from pyspark.sql.types import DoubleType
from pyspark.sql import functions as F
# Load the sample corpus, stored in libsvm format on the local filesystem.
# Each row carries a document label and a sparse term-count vector.
dataset = spark.read.format("libsvm").load("file:///usr/sample_lda_libsvm_data.txt")
# Display the rows once, without truncating the feature vectors.
# (The duplicated show() call was dropped: only a single table is produced below.)
dataset.show(truncate=False)
+-----+---------------------------------------------------------------+
|label|features |
+-----+---------------------------------------------------------------+
|0.0 |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0]) |
|1.0 |(11,[0,1,3,4,7,10],[1.0,3.0,1.0,3.0,2.0,1.0]) |
|2.0 |(11,[0,1,2,5,6,8,9],[1.0,4.0,1.0,4.0,9.0,1.0,2.0]) |
|3.0 |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,3.0,9.0]) |
|4.0 |(11,[0,1,2,3,4,6,9,10],[3.0,1.0,1.0,9.0,3.0,2.0,1.0,3.0]) |
|5.0 |(11,[0,1,3,4,5,6,7,8,9],[4.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0,4.0]) |
|6.0 |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,2.0,9.0]) |
|7.0 |(11,[0,1,2,3,4,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,1.0,2.0,1.0,3.0])|
|8.0 |(11,[0,1,3,4,5,6,7],[4.0,4.0,3.0,4.0,2.0,1.0,3.0]) |
|9.0 |(11,[0,1,2,4,6,8,9,10],[2.0,8.0,2.0,3.0,2.0,2.0,7.0,2.0]) |
|10.0 |(11,[0,1,2,3,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,2.0,3.0,3.0]) |
|11.0 |(11,[0,1,4,5,6,7,9],[4.0,1.0,4.0,5.0,1.0,3.0,1.0]) |
+-----+---------------------------------------------------------------+
# Configure an LDA estimator for 10 topics and at most 10 training iterations,
# then fit it on the loaded corpus.
lda = LDA().setK(10).setMaxIter(10)
model = lda.fit(dataset)
# Summarise every topic by its three highest-weighted terms.
topics = model.describeTopics(maxTermsPerTopic=3)
print("The topics described by their top-weighted terms:")
topics.show(truncate=False)
+-----+-----------+---------------------------------------------------------------+
|topic|termIndices|termWeights |
+-----+-----------+---------------------------------------------------------------+
|0 |[4, 7, 10] |[0.10782284792565977, 0.09748059037449146, 0.09623493647157101]|
|1 |[1, 6, 9] |[0.16755678146051728, 0.14746675884135615, 0.12291623854765772]|
|2 |[3, 10, 6] |[0.2365737123772152, 0.10497827056720986, 0.0917840535687615] |
|3 |[1, 3, 7] |[0.1015758016249506, 0.09974496621850018, 0.09902599541011434] |
|4 |[9, 10, 3] |[0.10479879348457938, 0.10207370742688827, 0.09818478669740321]|
|5 |[8, 5, 7] |[0.10843493028120557, 0.0970150424500599, 0.09334497822531877] |
|6 |[8, 5, 0] |[0.09874156962344234, 0.09654280831555884, 0.09565956823827508]|
|7 |[9, 4, 7] |[0.11252483000458603, 0.09755087587088286, 0.09643430900592685]|
|8 |[4, 1, 2] |[0.10994283713713536, 0.09410686873447463, 0.0937471573628509] |
|9 |[5, 4, 0] |[0.15265940066441183, 0.14015412109446546, 0.13878634876078264]|
+-----+-----------+---------------------------------------------------------------+
# View the topic distribution for every document: transform() returns the
# input rows with an added "topicDistribution" vector column (see the table
# printed below).
transformed = model.transform(dataset)
transformed.show(truncate=False)
+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|label|features |topicDistribution |
+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0.0 |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0]) |[0.004830688509084788,0.9563375886321935,0.004924669693727129,0.004830693291141946,0.004830675601199576,0.004830690970098452,0.004830731737552684,0.004830674902568036,0.004830730786933749,0.004922855875500012] |
|1.0 |(11,[0,1,3,4,7,10],[1.0,3.0,1.0,3.0,2.0,1.0]) |[0.008057778755383592,0.3149188541525326,0.00821568856074705,0.008057899973735082,0.00805773202965193,0.00805773219443841,0.00805772753178338,0.008057790266770967,0.008057845264839285,0.6204609512701176] |
|2.0 |(11,[0,1,2,5,6,8,9],[1.0,4.0,1.0,4.0,9.0,1.0,2.0]) |[0.004199741171245032,0.9620401773226402,0.004281469704273017,0.004199769097486346,0.004199807571784884,0.004199819505813106,0.004199835506062414,0.004199781772904878,0.004199800982100323,0.004279797365689855] |
|3.0 |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,3.0,9.0]) |[0.003714896800546591,0.5070516557688054,0.4631584573147577,0.003714914880264338,0.0037150085177011572,0.003714949896828997,0.0037149846555122436,0.003714886267751718,0.003714909060953893,0.003785336836878225] |
|4.0 |(11,[0,1,2,3,4,6,9,10],[3.0,1.0,1.0,9.0,3.0,2.0,1.0,3.0]) |[0.004024716198633711,0.004348960756766257,0.9633765414688664,0.004024715826289515,0.0040247523412803785,0.004024714760590197,0.004024750967476446,0.004024750137766685,0.004024763598734582,0.004101333943595805] |
|5.0 |(11,[0,1,3,4,5,6,7,8,9],[4.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0,4.0]) |[0.003714916720108325,0.004014106400247752,0.0037876992243613913,0.0037149522531312196,0.0037149927030871474,0.0037149587146134535,0.0037149750439419123,0.0037150099006180567,0.003714963609773339,0.9661934254301174] |
|6.0 |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,2.0,9.0]) |[0.003863637584067354,0.44120209378688086,0.5278152614977222,0.0038636593932357263,0.003863751204372584,0.0038636970054184935,0.003863731528120536,0.0038636169190041057,0.003863652151710295,0.003936898929468125] |
|7.0 |(11,[0,1,2,3,4,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,1.0,2.0,1.0,3.0])|[0.004390955723890411,0.004745014492795635,0.9600436030532219,0.004390986523517605,0.004391013571891052,0.004390968206875746,0.004391003804300225,0.004390998289212864,0.0043910030406065104,0.004474453293687847] |
|8.0 |(11,[0,1,3,4,5,6,7],[4.0,4.0,3.0,4.0,2.0,1.0,3.0]) |[0.004391082468515706,0.004744799620819518,0.004477230286216996,0.004391179034422902,0.004391083385391976,0.0043911102087152145,0.004391108242443274,0.0043911476110250714,0.0043911508747108575,0.9600401082677386] |
|9.0 |(11,[0,1,2,4,6,8,9,10],[2.0,8.0,2.0,3.0,2.0,2.0,7.0,2.0]) |[0.0033302167739046973,0.9698998050463385,0.0033949933226572675,0.0033302031974203014,0.0033302208173504686,0.003330228671311114,0.0033302277108795157,0.003330230056473623,0.0033302455331591036,0.0033936288705052665]|
|10.0 |(11,[0,1,2,3,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,2.0,3.0,3.0]) |[0.0041998552715806015,0.004538086674649772,0.9617828003374762,0.0041998854155415434,0.004199964563679233,0.004199898040748559,0.004199948969028732,0.004199941207400563,0.004199894377993083,0.004279725141901989] |
|11.0 |(11,[0,1,4,5,6,7,9],[4.0,1.0,4.0,5.0,1.0,3.0,1.0]) |[0.0048305604098789244,0.005219225001032762,0.004924487214200011,0.004830543265675906,0.00483056515654878,0.004830577688731923,0.004830590528195045,0.004830599936989683,0.004830615233900232,0.9560422355648467] |
+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
# Print the column names and types of the transformed DataFrame.
transformed.printSchema()
root
|-- label: double (nullable = true)
|-- features: vector (nullable = true)
|-- topicDistribution: vector (nullable = true)
Helper function: extract a single element of the topicDistribution vector so that each topic can become its own column
def ith_(v, i):
    """Return element ``i`` of ``v`` as a Python float, or ``None`` on failure.

    Args:
        v: An indexable sequence of numbers (used below on the
            ``topicDistribution`` vector column).
        i: Zero-based position of the element to extract.

    Returns:
        ``float(v[i])``, or ``None`` when the element cannot be read or
        converted (``v`` is ``None``, ``i`` is out of range, or the value
        is not numeric).
    """
    try:
        # float() normalises the element to a plain Python float; presumably
        # needed so the value fits the DoubleType the UDF is declared with.
        return float(v[i])
    except (IndexError, TypeError, ValueError):
        # IndexError: i out of range; TypeError: v is None / not indexable;
        # ValueError: element not convertible. Yield a null cell instead of
        # failing the whole job.
        return None
# Wrap the helper as a Spark UDF that produces a double column.
ith = F.udf(ith_, DoubleType())
# Read the number of topics from the estimator rather than repeating the
# literal 10, so the projection stays in sync if k is changed.
num_topics = lda.getK()
# One column per topic: topic_0 ... topic_{k-1}, each holding that entry of
# the document's topicDistribution vector.
topic_cols = [
    ith("topicDistribution", F.lit(i)).alias("topic_" + str(i))
    for i in range(num_topics)
]
df = transformed.select(["label"] + topic_cols)
df.show(truncate=False)
+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+
|label|topic_0 |topic_1 |topic_2 |topic_3 |topic_4 |topic_5 |topic_6 |topic_7 |topic_8 |topic_9 |
+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+
|0.0 |0.004830687791450502 |0.9563377999372255 |0.004830652446299898 |0.004830693203685635 |0.004924680975321234 |0.004830690324650106 |0.004830724790894176 |0.004830674545741453 |0.004830728328369402 |0.00492266765636222 |
|1.0 |0.00805777782592821 |0.3150888304586096 |0.008057821375392899 |0.008057900091752447 |0.00821563090347786 |0.008057731378987427 |0.008057716226340182 |0.00805778996991863 |0.008057841440203276 |0.6202909603293896 |
|2.0 |0.004199740539975822 |0.9620403414727842 |0.004199830281319767 |0.004199769011855544 |0.004281446354869374 |0.004199818930938506 |0.004199829456280457 |0.004199781450899189 |0.004199798835689997 |0.00427964366538733 |
|3.0 |0.003714883352496639 |0.39438266523895776 |0.0037149161634889914|0.003714899290148889 |0.5758276298046127 |0.003714939245435922 |0.0037149657297638815|0.003714878209574761 |0.0037148981104253493|0.0037853248550950695|
|4.0 |0.00402472343811409 |0.0043486720544167945|0.0040247584323080295|0.004024726616022349 |0.9633767817635327 |0.004024722506471514 |0.004024749723387701 |0.004024759068339994 |0.00402477228684825 |0.0041013341105585275|
|5.0 |0.0037149161731463167|0.00401410657859215 |0.0037150318186438148|0.003714952190974752 |0.0037876713720541993|0.003714958223027372 |0.003714969707955506 |0.0037150096299263177|0.003714961725756829 |0.9661934225799228 |
|6.0 |0.0038636235465470963|0.32506932380193027 |0.0038636563625666425|0.003863644344443025 |0.6439482136665527 |0.0038636867164242353|0.003863712160357752 |0.003863609226073573 |0.003863641557265962 |0.00393688861783849 |
|7.0 |0.004390963901259502 |0.004744419369141901 |0.004391020228883301 |0.00439099927884862 |0.9600441405838983 |0.004390977425037901 |0.004391002809855065 |0.004391008592998927 |0.004391013090740394 |0.004474454719336111 |
|8.0 |0.004391081853379135 |0.004744865767572997 |0.004391206214702098 |0.004391178993516226 |0.004477132667794462 |0.0043911096593825015|0.0043911019675074445|0.004391147323286589 |0.0043911486798455125|0.960040026873013 |
|9.0 |0.003330216240957084 |0.9698999783457445 |0.00333023738785573 |0.0033302030986131904|0.003394973102900875 |0.0033302280874212362|0.0033302228867079335|0.0033302291785187624|0.0033302391644247616|0.003393472506855918 |
|10.0 |0.004199858865711682 |0.004538534384183169 |0.004199958349762097 |0.004199894260340701 |0.9617823390796781 |0.004199903494953782 |0.0041999446501473445|0.004199945557171458 |0.004199899755712464 |0.004279721602339041 |
|11.0 |0.00483055973980833 |0.005219211145215135 |0.004830592303351509 |0.004830543225945144 |0.004924458988916403 |0.004830577090650675 |0.004830583633398643 |0.004830599625982923 |0.004830612825588896 |0.9560422614211423 |
+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+