KTV歌曲推荐

Wesley13
• 阅读 645

前言

上一篇使用逻辑回归预测了用户性别,由于矩阵比较稀疏所以会影响训练速度。所以考虑降维,降维方案有很多,本次只考虑PCA和SVD。

PCA和SVD原理

有兴趣的可以自己去研究一下 https://medium.com/@jonathan_hui/machine-learning-singular-value-decomposition-svd-principal-component-analysis-pca-1d45e885e491

我简述一下:

  • PCA是将高维数据映射到低维坐标系中,让数据尽量稀疏
  • SVD就是非方阵的PCA
  • 实际使用中SVD和PCA并无太大区别
  • 如果特征大于数据记录数,并不能有好的效果,具体原因自己可以去看。

代码

数据获取和处理

以前文章写过很多次,这里略过 原数据shape为:2000*1900

PCA和矩阵转换

查看最佳维度数

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
pca = PCA().fit(song_hot_matrix)
plt.plot(np.cumsum(pca.explained_variance_ratio_))
plt.xlabel('number of components')
plt.ylabel('cumulative explained variance');

KTV歌曲推荐

从图中可以看出大概1500维度已经可以达到90+解释性

保留99%矩阵解释性

pca = PCA(n_components=0.99, whiten=True)
song_hot_matrix_pca = pca.fit_transform(song_hot_matrix)

得到压缩后特征为: 2000*1565 并没有压缩多少

模型训练

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, Embedding,Flatten,Dropout
import matplotlib.pyplot as plt
from keras.utils import np_utils
from sklearn import datasets
from sklearn.model_selection import train_test_split

n_class=user_decades_encoder.get_class_count()
song_count=song_label_encoder.get_class_count()
print(n_class)
print(song_count)

train_X,test_X, train_y, test_y = train_test_split(song_hot_matrix_pca,
                                                   decades_hot_matrix,
                                                   test_size = 0.2,
                                                   random_state = 0)
train_count = np.shape(train_X)[0]
# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=song_hot_matrix_pca.shape[1], units=n_class))
model.add(Activation('softmax'))

# 选定loss函数和优化器
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

# 训练过程
print('Training -----------')
for step in range(train_count):
    scores = model.train_on_batch(train_X, train_y)
    if step % 50 == 0:
        print("训练样本 %d 个, 损失: %f, 准确率: %f" % (step, scores[0], scores[1]*100))
print('finish!')

训练结果:

训练样本 4750 个, 损失: 0.371499, 准确率: 83.207470
训练样本 4800 个, 损失: 0.381518, 准确率: 82.193959
训练样本 4850 个, 损失: 0.364363, 准确率: 83.763909
训练样本 4900 个, 损失: 0.378466, 准确率: 82.551670
训练样本 4950 个, 损失: 0.391976, 准确率: 81.756759
训练样本 5000 个, 损失: 0.378810, 准确率: 83.505565

测试集验证:

# 准确率评估
from sklearn.metrics import classification_report
scores = model.evaluate(test_X, test_y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))


Y_test = np.argmax(test_y, axis=1)
y_pred = model.predict_classes(song_hot_matrix_pca.transform(test_X))
print(classification_report(Y_test, y_pred))

accuracy: 50.20%

很明显已经过拟合

处理过拟合-增加Dropout

这里使用加Dropout,随机丢弃特征的方式处理过拟合,代码:

# 构建神经网络模型
model = Sequential()
model.add(Dropout(0.5))
model.add(Dense(input_dim=song_hot_matrix_pca.shape[1], units=n_class))
model.add(Activation('softmax'))

accuracy:70%

处理过拟合-L1L2正则

这里给权重增加正则

# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim=song_hot_matrix_pca.shape[1], units=n_class, kernel_regularizer=regularizers.l2(0.01)))
model.add(Activation('softmax'))

accuracy:62%

Well Done

其实SVD的做法与PCA类似,这里不再演示。经过我测试发现,在我的数据集上,PCA虽然加快了训练速度,但是丢弃了太多特征,导致数据很容易过拟合。加入Dropout或者增加正则相可以改善过拟合的情况,下一篇会分享自编码降维。

点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
待兔 待兔
6个月前
手写Java HashMap源码
HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程22
Jacquelyn38 Jacquelyn38
3年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
Easter79 Easter79
3年前
Twitter的分布式自增ID算法snowflake (Java版)
概述分布式系统中,有一些需要使用全局唯一ID的场景,这种时候为了防止ID冲突可以使用36位的UUID,但是UUID有一些缺点,首先他相对比较长,另外UUID一般是无序的。有些时候我们希望能使用一种简单一些的ID,并且希望ID能够按照时间有序生成。而twitter的snowflake解决了这种需求,最初Twitter把存储系统从MySQL迁移
Wesley13 Wesley13
3年前
mysql设置时区
mysql设置时区mysql\_query("SETtime\_zone'8:00'")ordie('时区设置失败,请联系管理员!');中国在东8区所以加8方法二:selectcount(user\_id)asdevice,CONVERT\_TZ(FROM\_UNIXTIME(reg\_time),'08:00','0
Wesley13 Wesley13
3年前
00:Java简单了解
浅谈Java之概述Java是SUN(StanfordUniversityNetwork),斯坦福大学网络公司)1995年推出的一门高级编程语言。Java是一种面向Internet的编程语言。随着Java技术在web方面的不断成熟,已经成为Web应用程序的首选开发语言。Java是简单易学,完全面向对象,安全可靠,与平台无关的编程语言。
Stella981 Stella981
3年前
Docker 部署SpringBoot项目不香吗?
  公众号改版后文章乱序推荐,希望你可以点击上方“Java进阶架构师”,点击右上角,将我们设为★“星标”!这样才不会错过每日进阶架构文章呀。  !(http://dingyue.ws.126.net/2020/0920/b00fbfc7j00qgy5xy002kd200qo00hsg00it00cj.jpg)  2
Wesley13 Wesley13
3年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
Python进阶者 Python进阶者
1年前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这