-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathapp.py
More file actions
183 lines (150 loc) · 7.47 KB
/
app.py
File metadata and controls
183 lines (150 loc) · 7.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import streamlit as st
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# Page setup: browser-tab title plus a wide layout so the results table and
# charts get the full window width. set_page_config is kept as the first
# Streamlit call the script makes (older Streamlit versions require that).
_APP_TITLE = "Game Churn Prediction AI"
_INTRO_MARKDOWN = """
Welcome to the Game Churn Prediction AI. This tool allows you to upload player gameplay and engagement data
to predict which players are at risk of leaving (churning).
"""

st.set_page_config(layout="wide", page_title=_APP_TITLE)
st.title(_APP_TITLE)
st.markdown(_INTRO_MARKDOWN)
# Load the trained model and expected feature names
@st.cache_resource
def load_model():
    """Load the trained churn model and the feature names it was trained on.

    Decorated with ``st.cache_resource`` so the result is cached across
    Streamlit reruns rather than re-read on every interaction.

    Returns:
        ``(model, features)`` when both files load, or ``(None, None)`` when
        either ``models/churn_model.pkl`` or ``models/model_features.pkl`` is
        missing. Only ``FileNotFoundError`` is handled; any other loading
        error propagates to the caller.

    Note:
        ``joblib.load`` unpickles its input, so these paths must only ever
        point at trusted, locally produced files.
    """
    try:
        trained_model = joblib.load("models/churn_model.pkl")
        feature_names = joblib.load("models/model_features.pkl")
    except FileNotFoundError:
        # Report "no model" to the caller instead of crashing the page.
        return None, None
    return trained_model, feature_names
# Probability cut-offs separating the Low / Medium / High churn-risk bands.
LOW_RISK_THRESHOLD = 0.4
HIGH_RISK_THRESHOLD = 0.7

# Columns this app adds to the uploaded table.
RESULT_COLUMNS = ['Churn Probability', 'Risk Level']

# Columns that are never model inputs: the player identifier and any target
# columns a training-style export might still contain.
NON_FEATURE_COLUMNS = ['PlayerID', 'EngagementLevel', 'Churn']


def _build_feature_matrix(df_raw, feature_names):
    """Return ``df_raw`` preprocessed and aligned to the model's feature columns.

    Missing numeric values are filled with the column median and missing text
    values with the column mode. Identifier/target columns are dropped and
    text columns are one-hot encoded. The result is then reindexed to
    ``feature_names``: features absent from the upload become 0, and columns
    the model does not know are discarded. ``df_raw`` itself is not modified.
    """
    df = df_raw.copy()
    # 1. Handle missing values.
    df = df.fillna(df.median(numeric_only=True))
    for col in df.select_dtypes(include=['object']).columns:
        modes = df[col].mode()
        # An all-empty column has no mode; leave it alone rather than failing.
        if not modes.empty:
            df[col] = df[col].fillna(modes.iloc[0])
    # 2. Drop identifier / target columns when present.
    df = df.drop(columns=NON_FEATURE_COLUMNS, errors='ignore')
    # 3. One-hot encode WITHOUT drop_first. With drop_first, the category that
    #    gets dropped depends on which categories appear in *this* upload, so
    #    it can discard a column the model was trained on (e.g. an upload in
    #    which every player shares one category). Encoding every category and
    #    letting the reindex below discard unknown columns is order-independent.
    df = pd.get_dummies(df)
    # 4. Align with the trained model: same columns, same order, 0 for any
    #    feature missing from the upload. Reindexing keeps the row index,
    #    unlike filling an empty DataFrame column by column (which yields zero
    #    rows when the first expected feature is absent).
    return df.reindex(columns=feature_names, fill_value=0)


def _get_risk_level(prob):
    """Map a churn probability to its 'Low' / 'Medium' / 'High' risk band."""
    if prob < LOW_RISK_THRESHOLD:
        return 'Low'
    if prob < HIGH_RISK_THRESHOLD:
        return 'Medium'
    return 'High'


def _color_risk(val):
    """Return the CSS used to colour a 'Risk Level' cell."""
    color = {'Low': 'green', 'Medium': 'orange'}.get(val, 'red')
    return f'color: {color}; font-weight: bold'


def _render_feature_importance(fitted_model, feature_names):
    """Chart the model's ten most important features and summarise the top three."""
    st.subheader("Model Feature Importance")
    st.markdown("""
What drives churn? Check out the most important factors considered by our Random Forest model.
""")
    # Not every estimator exposes feature_importances_. Without this guard the
    # AttributeError would abort the page before the download button renders.
    importances = getattr(fitted_model, 'feature_importances_', None)
    if importances is None:
        st.info("Feature importance is not available for the loaded model.")
        return

    feat_imp_df = pd.DataFrame({
        'Feature': feature_names,
        'Importance': importances
    }).sort_values(by='Importance', ascending=False)

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x='Importance', y='Feature', data=feat_imp_df.head(10), ax=ax, palette='viridis')
    ax.set_title('Top 10 Drivers of Churn')
    st.pyplot(fig)
    # Streamlit re-runs the whole script on every interaction; close the
    # figure so pyplot does not keep accumulating open figures.
    plt.close(fig)

    # --- Auto-Generated Feature Importance Insights ---
    top_features = feat_imp_df.head(3)['Feature'].tolist()
    st.info(f"**Insights:** The model indicates churn is primarily driven by: **{', '.join(top_features)}**.")
    # Substring match, so encoded dummies such as 'Gender_Male' also count.
    demographic_keywords = ['Age', 'Gender', 'Location']
    engagement_is_key = not any(
        keyword in feat for feat in top_features for keyword in demographic_keywords
    )
    if engagement_is_key:
        st.caption("Note: Engagement behavior is more influential than demographics (e.g., age, gender, location).")


model, expected_features = load_model()
if model is None:
    st.error("Model not found! Please ensure 'models/churn_model.pkl' and 'models/model_features.pkl' exist.")
else:
    # upload-chart
    st.sidebar.header("1. Upload Data")
    uploaded_file = st.sidebar.file_uploader("Upload a CSV file with player data", type=["csv"])
    # output-chart
    st.sidebar.markdown("""
**Required Columns (Example):**
- Age, Gender, Location
- GameGenre, PlayTimeHours
- InGamePurchases, GameDifficulty
- SessionsPerWeek, AvgSessionDurationMinutes
- PlayerLevel, AchievementsUnlocked
""")
    if uploaded_file is not None:
        # Load the user uploaded dataset
        st.subheader("Uploaded Data Preview")
        try:
            df_raw = pd.read_csv(uploaded_file)
            st.dataframe(df_raw.head())

            X = _build_feature_matrix(df_raw, expected_features)

            # Churn probability is taken from the model's second class column
            # (classes_[1]). NOTE(review): confirm churn is the positive label.
            churn_probs = model.predict_proba(X)[:, 1]

            # Attach predictions to the original dataframe to display.
            df_raw['Churn Probability'] = churn_probs
            df_raw['Risk Level'] = df_raw['Churn Probability'].apply(_get_risk_level)

            st.subheader("Prediction Results")
            st.markdown("Here is the estimated churn probability and risk level for each player.")

            # --- Churn Risk Explanation Box ---
            # Built from the same constants _get_risk_level uses, so the
            # legend cannot drift from the logic ('40%' and '70%' today).
            low_pct = f"{LOW_RISK_THRESHOLD:.0%}"
            high_pct = f"{HIGH_RISK_THRESHOLD:.0%}"
            col1, col2, col3 = st.columns(3)
            with col1:
                st.success(f"Low Risk → churn probability below {low_pct}")
            with col2:
                st.warning(f"Medium Risk → churn probability between {low_pct} and {high_pct}")
            with col3:
                st.error(f"High Risk → churn probability above {high_pct}")

            # Raise pandas' Styler cell limit so large uploads still render.
            pd.set_option("styler.render.max_elements", 4000000)
            # Results first, then the uploaded columns in their original order.
            # Selecting by name (rather than slicing off the last two columns)
            # stays correct when the upload already contained the result
            # columns, e.g. a re-uploaded predictions CSV.
            display_cols = RESULT_COLUMNS + [c for c in df_raw.columns if c not in RESULT_COLUMNS]
            display_df = df_raw[display_cols]
            try:
                # pandas >= 2.1.0 uses .map
                styled_df = display_df.style.map(_color_risk, subset=['Risk Level'])
            except AttributeError:
                # fallback for older versions
                styled_df = display_df.style.applymap(_color_risk, subset=['Risk Level'])
            st.dataframe(styled_df)

            # --- Feature Importance ---
            _render_feature_importance(model, expected_features)

            # Download predicted results
            st.subheader("Download Predictions")
            csv = df_raw.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="Download Results as CSV",
                data=csv,
                file_name='churn_predictions.csv',
                mime='text/csv',
            )
        except Exception as e:
            # UI boundary: show the failure to the user instead of a traceback.
            st.error(f"Error processing the file: {e}")
    else:
        st.info("Please upload a CSV file from the sidebar to see predictions.")