最近想研究一下推荐系统,其中比较经典的算法就是 SVD(这里实现的是用随机梯度下降训练的带偏置矩阵分解模型)。具体理论不多讲了,直接上代码。
先贴张效果图吧。数据集规模:userNum = 6040,itemNum = 3900。
本文链接:http://www.cnblogs.com/wn19910213/p/3617781.html
上代码咯:
SVD.h
1 #ifndef SVD_H_INCLUDED
2 #define SVD_H_INCLUDED
3
#include <cstring>
#include <string>
#include <vector>
6
7 using namespace std;
8
/// Rating predictor built from per-user and per-item biases plus latent
/// factor vectors (a biased matrix-factorisation model).
///
/// The object stores four raw arrays (Bi, Bu, Qi, Pu) and releases all of them
/// with delete[] in its destructor. Copying would therefore make two objects
/// free the same memory, so the copy operations are declared private and left
/// undefined.
class SVD{
public:
    /// Arguments, in order: item biases, user biases, requested latent-factor
    /// count, item factor matrix, user factor matrix. A NULL array asks the
    /// constructor to allocate and initialise that array itself.
    SVD(double*,double*,int,double**,double**);
    ~SVD();

    /// Runs one pass over the ratings file and updates the model in place.
    void loadTrainFile(std::string);

    /// Returns the clamped prediction for one (user, item) pair.
    /// NOTE(review): the first parameter (the average rating) is an int, so a
    /// fractional average is truncated at the call; it is kept as-is here
    /// because the out-of-line definition must match this declaration.
    double predictScore(int,double,double,double*,double*);

    /// Returns the RMSE of the given model parameters on a ratings file.
    double Validate(std::string,double,double*,double*,double**,double**);

    // The data members stay public because main() passes Bu/Bi/Pu/Qi straight
    // back into Validate().
    double* Bi;    // item biases
    double* Bu;    // user biases
    int factor;    // number of latent factors per user/item vector
    double** Qi;   // item factor vectors, one row per item
    double** Pu;   // user factor vectors, one row per user

private:
    // Non-copyable: declared but intentionally not defined.
    SVD(const SVD&);
    SVD& operator=(const SVD&);
};
24
25
26 #endif // SVD_H_INCLUDED
SVD.cpp
1 #include <cmath>
2 #include <iostream>
3 #include <cstring>
4 #include <cstdlib>
5 #include <fstream>
6 #include "SVD.h"
7
8
// Model dimensions and training hyper-parameters shared by main() and the
// SVD member functions.
int userNum = 6040;          // number of users: length of Bu, row count of Pu
int itemNum = 3900;          // number of items: length of Bi, row count of Qi
// Baseline added to every prediction (presumably the mean training rating).
// NOTE(review): predictScore() takes this through an int parameter, so only
// the integer part (3) currently reaches the prediction.
double AVG = 3.579231;
double lr = 0.01;            // step size of the gradient updates
double theta = 0.05;         // regularisation weight in the gradient updates
double preRmse = 1000000.0;  // RMSE of the previous epoch; starts huge so the first epoch is always accepted
15
16 int main()
17 {
18 string trainFile = "/home/ja/CADATA/SVD/ml_data/training.txt";
19 string testFile = "/home/ja/CADATA/SVD/ml_data/test.txt";
20 srand(0);
21 SVD svd(NULL,NULL,0,NULL,NULL);
22
23 for(size_t i=0;i<100;i++){
24 svd.loadTrainFile(trainFile);
25 //lr *= 0.9;
26 double curRmse = svd.Validate(testFile,AVG,svd.Bu,svd.Bi,svd.Pu,svd.Qi);
27 cout << "test_Rmse in step " << i << ": " << curRmse << endl;
28 if(curRmse >= preRmse){
29 break;
30 }
31 else{
32 preRmse = curRmse;
33 }
34 }
35 return 0;
36 }
37
38 double SVD::Validate(string testfile,double avg,double* bu,double* bi,double** pu,double** qi){
39 ifstream fin(testfile.c_str());
40 if(!fin){
41 cout << "error" << endl;
42 }
43 int userId,itemId,rating,t;
44 int n = 0;
45 double rmse;
46 while(fin >> userId >> itemId >> rating >> t){
47 n++;
48 double pScore = predictScore(avg,bu[userId-1],bi[itemId-1],pu[userId-1],qi[itemId-1]);
49 rmse += (rating - pScore) * (rating - pScore);
50 }
51 fin.close();
52 return sqrt(rmse/n);
53 }
54
55 double SVD::predictScore(int avg,double bu,double bi,double* pu,double* qi){
56 double tmp = 0.0;
57 for(size_t i=0;i<factor;i++){
58 tmp += pu[i] * qi[i];
59 }
60
61 double score = avg + bu + bi + tmp;
62 if(score > 5){
63 score = 5;
64 }
65 if(score < 1){
66 score = 1;
67 }
68 return score;
69 }
70
71 void SVD::loadTrainFile(string file){
72 ifstream fin(file.c_str());
73 if(!fin){
74 cout << "error" << endl;
75 }
76
77 int userId,itemId,rating,t;
78 while(fin >> userId >> itemId >> rating >> t){
79 double predict = predictScore(AVG,Bu[userId-1],Bi[itemId-1],Pu[userId-1],Qi[itemId-1]);
80 double error = rating - predict;
81 Bu[userId-1] += lr * (error - theta * Bu[userId-1]);
82 Bi[itemId-1] += lr * (error - theta * Bi[itemId-1]);
83
84 for(size_t i=0;i<factor;i++){
85 double Tmp = Pu[userId-1][i];
86 Pu[userId-1][i] += lr * (error * Qi[itemId-1][i] - theta * Pu[userId-1][i]);
87 Qi[itemId-1][i] += lr * (error * Tmp - theta * Qi[itemId-1][i]);
88 }
89 }
90 fin.close();
91 }
92
93 SVD::SVD(double* bi,double* bu,int k,double** qi,double** pu){
94
95 if(bi == NULL){
96 Bi = new double[itemNum];
97 for(size_t i=0;i<itemNum;i++){
98 Bi[i] = 0.0;
99 }
100 }
101 else{
102 Bi = bi;
103 }
104
105 if(bu == NULL){
106 Bu = new double[userNum];
107 for(size_t i=0;i<userNum;i++){
108 Bu[i] = 0.0;
109 }
110 }
111 else{
112 Bu = bu;
113 }
114
115 factor = 10;
116
117 if(qi == NULL){
118 Qi = new double* [itemNum];
119 for(size_t i=0;i<itemNum;i++){
120 Qi[i] = new double[factor];
121 }
122
123 for(size_t i=0;i<itemNum;i++){
124 for(size_t j=0;j<factor;j++){
125 Qi[i][j] = 0.1 * (rand() / (RAND_MAX + 1.0)) / sqrt(factor);
126 }
127 }
128 }
129 else{
130 Qi = qi;
131 }
132
133 if(pu == NULL){
134 Pu = new double* [userNum];
135 for(size_t i=0;i<userNum;i++){
136 Pu[i] = new double[factor];
137 }
138
139 for(size_t i=0;i<userNum;i++){
140 for(size_t j=0;j<factor;j++){
141 Pu[i][j] = 0.1 * (rand() / (RAND_MAX + 1.0)) / sqrt(factor);
142 }
143 }
144 }
145 else{
146 Pu = pu;
147 }
148 }
149
150 SVD::~SVD(){
151 delete[] Bi;
152 delete[] Bu;
153 for(size_t i=0;i<userNum;i++){
154 delete[] Pu[i];
155 }
156 for(size_t i=0;i<itemNum;i++){
157 delete[] Qi[i];
158 }
159 delete[] Pu;
160 delete[] Qi;
161 }
原文链接: https://www.cnblogs.com/wn19910213/p/3617781.html
欢迎关注
微信关注下方公众号,第一时间获取干货硬货;公众号内回复【pdf】免费获取数百本计算机经典书籍
原创文章受到原创版权保护。转载请注明出处:https://www.ccppcoding.com/archives/124379
非原创文章文中已经注明原地址,如有侵权,联系删除
关注公众号【高性能架构探索】,第一时间获取最新文章
转载文章受原作者版权保护。转载请注明原作者出处!