Classifying Dead Stars

Pulsar classification is a great example of where machine learning can be used beneficially in astrophysics. It’s not the most straightforward classification problem, but here I’m going to outline the basics using the scikit-learn random forest classifier. This post was inspired by Rob Lyon’s pulsar classification tutorials in the IAU OAD Data Science Toolkit.

This post is in Python 3.

I see dead… stars?

Pulsars are “pulsating radio sources”, now known to be caused by rapidly rotating neutron stars. Neutron stars are the relics of dead massive stars. They’re small and extremely dense – think about something the same mass as the Sun crammed into a radius roughly the same as the M25 motorway. You can read all about them here.

An artist’s impression of a pulsar. Image credit: Joeri van Leeuwen, License: CC-BY-SA

You can even listen to them (if you really want to…)

PSR B0329+54: This pulsar is a typical, normal pulsar, rotating with a period of 0.714519 seconds, i.e. close to 1.40 rotations/sec.

Pulsars are pretty interesting objects in their own right. They are used as a probe of stellar evolution and, because of their extremely high densities, to test general relativity. These days they’re also used to detect and map gravitational wave signatures. However, identifying them in the data streams from radio telescopes is not trivial: there are lots of man-made sources of radio frequency interference (RFI) that can mimic the signals from pulsars. Classifying candidate data samples as pulsar or not pulsar is serious business.

The individual pulses are all different, so astronomers stack them up and create an average integrated pulse profile to characterise a particular pulsar:

Image credit: Essentials of Radio Astronomy
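To make the stacking idea concrete, here’s a minimal sketch that folds a simulated time series at a known period and averages the samples in phase bins. The pulse shape, noise level, sampling rate and the helper name `fold_profile` are all made-up assumptions for illustration. Real pipelines also have to search for the period and correct for dispersion first.

```python
import numpy as np

def fold_profile(times, samples, period, n_bins=64):
    """Hypothetical helper: average the samples into n_bins phase bins at the given period."""
    phase = (times % period) / period                          # rotational phase in [0, 1)
    bins = np.minimum((phase * n_bins).astype(int), n_bins - 1)
    sums = np.bincount(bins, weights=samples, minlength=n_bins)
    counts = np.bincount(bins, minlength=n_bins)
    return sums / np.maximum(counts, 1)                        # mean value in each phase bin

# Simulated data (assumption): a narrow pulse every 0.714519 s, buried in noise
rng = np.random.default_rng(42)
period = 0.714519
times = np.arange(0.0, 60.0, 1e-3)                            # 60 s sampled every 1 ms
phase = (times % period) / period
pulse = np.exp(-0.5 * ((phase - 0.3) / 0.01) ** 2)            # Gaussian pulse centred on phase 0.3
samples = pulse + rng.normal(0.0, 2.0, times.size)            # single pulses hidden by the noise

profile = fold_profile(times, samples, period)
print("peak bin:", profile.argmax(), "of", profile.size)
```

Each individual pulse has a peak of 1 against noise with a standard deviation of 2, so it’s invisible on its own. Averaging roughly 900 samples per phase bin beats the noise down enough for the pulse to stand out around phase 0.3.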

Additionally, the pulse arrives at different times at different radio frequencies. The delay from frequency to frequency is caused by the ionised interstellar medium and is known as dispersion. It looks like this:

Image credit: Essentials of Radio Astronomy
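The size of the delay follows the standard cold-plasma dispersion relation. The arrival time at frequency f lags by roughly 4.15 ms × DM × (f / 1 GHz)^-2, where the dispersion measure (DM) is the column density of free electrons along the line of sight in pc cm^-3. Here’s a minimal sketch. The helper name `dispersion_delay_ms` and the example DM and band edges are hypothetical values chosen for illustration.

```python
def dispersion_delay_ms(dm, f_lo_ghz, f_hi_ghz):
    """Hypothetical helper: extra arrival delay (ms) at f_lo_ghz relative to f_hi_ghz."""
    k_dm = 4.148808  # dispersion constant, ms GHz^2 pc^-1 cm^3
    return k_dm * dm * (f_lo_ghz ** -2 - f_hi_ghz ** -2)

# Hypothetical example: DM = 100 pc cm^-3 across a 1.2-1.5 GHz observing band
delay = dispersion_delay_ms(100.0, 1.2, 1.5)
print(f"{delay:.1f} ms")  # prints 103.7 ms
```

A smear of around 100 ms across the band is comparable to or larger than many pulsar periods. That’s why the delay has to be corrected before the pulses are stacked.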

Astronomers fit for the shape of the delay in order to compensate for its effect, but there’s always an uncertainty associated with the fit. That is expressed in the DM-SNR (“dispersion-measure-signal-to-noise-ratio”) curve, which looks like this:


Putting these two curves together, each pulsar candidate has eight numerical features that can be extracted as standard: the mean, standard deviation, excess kurtosis and skewness of the integrated pulse profile, and the same four statistics for the DM-SNR curve.
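As a rough sketch of how those four numbers could be computed from a curve, here’s a numpy version. It assumes population (biased) moments. I haven’t checked exactly which estimator the HTRU2 authors used, so treat `four_stats` as a hypothetical helper rather than the dataset’s own code.

```python
import numpy as np

def four_stats(curve):
    """Hypothetical helper: mean, standard deviation, excess kurtosis and skewness of a 1-D curve."""
    x = np.asarray(curve, dtype=float)
    mean = x.mean()
    std = x.std()                      # population standard deviation
    z = (x - mean) / std               # standardised values
    ex_kurt = np.mean(z ** 4) - 3.0    # fourth standardised moment minus 3 (0 for a Gaussian)
    skew = np.mean(z ** 3)             # third standardised moment
    return mean, std, ex_kurt, skew

# Toy "profile" (assumption); a real one would be the folded pulse profile or DM-SNR curve
mean, std, ex_kurt, skew = four_stats([0.0, 1.0, 2.0, 3.0, 4.0])
print(mean, std, ex_kurt, skew)
```

The return order mirrors the feature order used in the dataset file: mean, standard deviation, excess kurtosis, skewness.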


Getting set-up

First some general libraries:

import numpy as np   # for array stuff
import matplotlib.pyplot as pl   # for plotting stuff
import pandas as pd  # for data handling

Then a bunch of scikit-learn libraries:

from sklearn.ensemble import RandomForestClassifier
from sklearn import model_selection
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import roc_curve, roc_auc_score

I’m also using scikit-plot, which I only recently discovered and has made my life much easier 🙂

import scikitplot as skplt

I’m using the HTRU2 dataset. This dataset compiles the eight features described above for 1,639 known pulsars and 16,259 additional candidates that were later identified as RFI/noise. You can find a full description of the dataset in this paper.
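Those counts mean the classes are heavily imbalanced. A quick back-of-the-envelope check using only the numbers above shows why plain accuracy can be misleading here:

```python
n_pulsars = 1639
n_rfi = 16259
n_total = n_pulsars + n_rfi

pulsar_fraction = n_pulsars / n_total
# A trivial "classifier" that labels every candidate as non-pulsar scores this accuracy
baseline_accuracy = n_rfi / n_total

print(n_total)                     # 17898
print(f"{pulsar_fraction:.3f}")    # 0.092
print(f"{baseline_accuracy:.3f}")  # 0.908
```

A model that never finds a single pulsar would still be about 91% accurate. That’s one reason per-class precision/recall and ROC AUC are more informative than accuracy alone for this problem.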

I added a header row with the feature names to the CSV for the purpose of this example – you can find my version in the IAU OAD Data Science Toolkit here.

df = pd.read_csv('data/pulsar.csv')

You can take a look at the names of the features in the file like this (pf = integrated profile & dm = DM-SNR curve):

feature_names = df.columns.values[0:-1]   # every column except the final 'class' label
print(feature_names)

['mean_int_pf' 'std_pf' 'ex_kurt_pf' 'skew_pf' 'mean_dm' 'std_dm'
'kurt_dm' 'skew_dm']

and we can check just how much data we’re dealing with:

print('Dataset has %d rows and %d columns including features and labels' % (df.shape[0], df.shape[1]))

Dataset has 17898 rows and 9 columns including features and labels

We’re going to start by separating the numerical feature data from the class labels for all the candidates. To get the feature data on its own we can just strip off the column containing the class labels:

features = df.drop('class', axis=1)

The labels for each object tell us about the target class, and we can create an array of those labels by extracting the column from the original dataset:

targets = df['class']

Setting up the Machine Learning

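Now we need to split our labelled data into two separate datasets: one to train the classifier and one to test the fitted machine learning model. To do this we can use the function train_test_split from the scikit-learn library. Setting test_size=0.33 holds back a third of the candidates for testing, and random_state fixes the random shuffle so that the split is reproducible: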

X_train, X_test, y_train, y_test = train_test_split(features, targets, test_size=0.33, random_state=66)
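To see exactly how big the two pieces are: with a float test_size, scikit-learn turns the fraction into a sample count by rounding up, and the remainder goes to training. Here’s the arithmetic as a small sketch that mirrors that rounding rule (it doesn’t call scikit-learn itself):

```python
import math

n_samples = 17898
test_size = 0.33

# The test fraction is rounded up to a whole number of samples; training gets the rest
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test
print(n_train, n_test)  # 11991 5907
```

The 5,907 test candidates match the total in the "support" column of the classification report further down.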

Our dataset is now in a suitable state to start training the classifier.

To start with, we need to create the random forest classifier from scikit-learn. Here it has 10 trees (n_estimators) and uses 2 parallel jobs (n_jobs):

RFC = RandomForestClassifier(n_jobs=2,n_estimators=10)

…and we can immediately fit the machine learning model to our training data:

RFC.fit(X_train, y_train)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=2,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)

We can then use the trained classifier to predict the label for the test data that we split out earlier:

rfc_predict = RFC.predict(X_test)
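Under the hood, scikit-learn’s random forest doesn’t take a hard majority vote. It averages the class-probability estimates of the individual trees (stored in `RFC.estimators_`) and picks the class with the highest mean probability. The toy numpy sketch below uses made-up per-tree probabilities for three trees and two candidates to show where the two rules can disagree:

```python
import numpy as np

# Made-up probabilities, shape (trees, candidates, classes); columns = [non-pulsar, pulsar]
tree_probas = np.array([
    [[0.9, 0.1], [0.45, 0.55]],  # tree 1
    [[0.8, 0.2], [0.45, 0.55]],  # tree 2
    [[1.0, 0.0], [1.00, 0.00]],  # tree 3
])

mean_proba = tree_probas.mean(axis=0)   # averaged probabilities (what predict_proba reports)
labels = mean_proba.argmax(axis=1)      # class index with the highest mean probability

votes = tree_probas.argmax(axis=2)                   # each tree's hard vote, shape (trees, candidates)
majority = (votes.mean(axis=0) > 0.5).astype(int)    # binary majority vote, for comparison

print(mean_proba)
print(labels, majority)
```

For the second candidate, two of the three trees lean weakly towards "pulsar" but the third is certain it isn’t one. Averaging gives "non-pulsar", whereas a hard vote would say "pulsar". In the real classifier the winning index is mapped back to a label through `RFC.classes_`.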

Evaluating Performance

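So how did we do? We need to evaluate the performance of our classifier.
A good first step is to compute a cross-validated score. This estimates how well our machine learning model generalises to data it hasn’t been trained on, which helps to show whether we have over-fitted the training data. Here we use 10-fold cross-validation, scored by the area under the ROC curve (AUC):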

rfc_cv_score = cross_val_score(RFC, features, targets, cv=10, scoring='roc_auc')
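When cv is an integer and the estimator is a classifier, cross_val_score uses stratified k-fold splitting (without shuffling by default). Each of the 10 folds keeps roughly the same pulsar/non-pulsar ratio as the full dataset. Each fold is held out once while a fresh copy of the model is trained on the other nine. Here’s a simplified numpy sketch of stratified fold assignment. The helper `stratified_folds` and the toy labels are hypothetical, and the sketch doesn’t reproduce scikit-learn’s exact index ordering.

```python
import numpy as np

def stratified_folds(labels, k):
    """Hypothetical helper: give each sample a fold id 0..k-1, dealing each class out round-robin."""
    labels = np.asarray(labels)
    fold_ids = np.empty(labels.size, dtype=int)
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)       # positions of this class
        fold_ids[idx] = np.arange(idx.size) % k   # spread them evenly across the folds
    return fold_ids

# Toy labels with a ~9% positive rate, similar to HTRU2 (assumption for illustration)
labels = np.array([1] * 9 + [0] * 91)
folds = stratified_folds(labels, k=10)

for f in range(10):
    held_out = folds == f
    print(f, held_out.sum(), labels[held_out].sum())   # fold size, positives in the fold
```

Without stratification, a purely random split of a dataset this imbalanced could leave some folds with very few pulsars, making their AUC values noisy.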

Let’s print out the various evaluation criteria:

print("=== Confusion Matrix ===")
print(confusion_matrix(y_test, rfc_predict))
print("=== Classification Report ===")
print(classification_report(y_test, rfc_predict, target_names=['Non Pulsar','Pulsar']))
print("=== All AUC Scores ===")
print(rfc_cv_score)
print("=== Mean AUC Score ===")
print("Mean AUC Score - Random Forest: ", rfc_cv_score.mean())

=== Confusion Matrix ===
[[5327   35]
 [  93  452]]

=== Classification Report ===
              precision    recall  f1-score   support

  Non Pulsar       0.98      0.99      0.99      5362
      Pulsar       0.93      0.83      0.88       545

   micro avg       0.98      0.98      0.98      5907
   macro avg       0.96      0.91      0.93      5907
weighted avg       0.98      0.98      0.98      5907

=== All AUC Scores ===
[0.92774615 0.94807886 0.96225025 0.96079711 0.96652717 0.9472501
 0.96336963 0.95761145 0.96597591 0.96716753]

=== Mean AUC Score ===
Mean AUC Score - Random Forest:  0.956677415292086
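The classification report can be reconstructed by hand from the confusion matrix, where rows are the true class and columns the predicted class. Here’s a small sketch using the test-set counts printed above. The helper `prf` is just for illustration.

```python
# Test-set confusion matrix from above: rows = true class, columns = predicted class
tn, fp = 5327, 35   # true non-pulsars: predicted non-pulsar, predicted pulsar
fn, tp = 93, 452    # true pulsars:     predicted non-pulsar, predicted pulsar

def prf(tp, fp, fn):
    """Hypothetical helper: precision, recall and F1 for one class from its TP, FP and FN counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return precision, recall, 2 * precision * recall / (precision + recall)

pulsar = prf(tp, fp, fn)        # pulsar treated as the positive class
non_pulsar = prf(tn, fn, fp)    # roles of FP and FN swap when non-pulsar is the positive class
macro = [(a + b) / 2 for a, b in zip(non_pulsar, pulsar)]   # unweighted mean over the two classes
accuracy = (tp + tn) / (tp + tn + fp + fn)

print("pulsar     P/R/F1:", [round(x, 2) for x in pulsar])      # [0.93, 0.83, 0.88]
print("non-pulsar P/R/F1:", [round(x, 2) for x in non_pulsar])  # [0.98, 0.99, 0.99]
print("macro avg  P/R/F1:", [round(x, 2) for x in macro])       # [0.96, 0.91, 0.93]
print("accuracy:", round(accuracy, 2))                          # 0.98
```

For a single-label problem like this one, micro-averaged precision, recall and F1 all reduce to the overall accuracy. That’s why the "micro avg" row reads 0.98 across the board. The 0.83 recall for pulsars is the number to watch: 93 of the 545 real pulsars in the test set were missed.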

We can make a more visual representation of the confusion matrix using the scikit-plot library. To do this we need the predicted class for each candidate from cross-validation, rather than the Area Under Curve (AUC) scores:

predictions = cross_val_predict(RFC, features, targets, cv=2)
skplt.metrics.plot_confusion_matrix(targets, predictions, normalize=True)
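With normalize=True, scikit-plot divides each row of the confusion matrix by its row total. Each row then shows the fraction of that true class assigned to each predicted class, and the diagonal is the per-class recall. As an illustration, here’s the same normalisation applied to the test-set matrix printed earlier. The plotted matrix uses cross-validated predictions over the whole dataset, so its numbers will differ somewhat.

```python
import numpy as np

cm = np.array([[5327, 35],
               [93, 452]])                     # test-set confusion matrix from above
cm_norm = cm / cm.sum(axis=1, keepdims=True)   # divide each row by its total so rows sum to 1
print(np.round(cm_norm, 3))
```

The bottom-left entry says that about 17% of true pulsars were labelled as non-pulsars, which is much easier to read off than the raw count of 93.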


To plot the ROC curve we need to find the probabilities for each target class separately. We can do this with the predict_proba function:

probas = RFC.predict_proba(X_test)
skplt.metrics.plot_roc(y_test, probas)
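predict_proba returns one column per class, in the order given by `RFC.classes_`. An ROC curve for the pulsar class is traced by sweeping a threshold down through the pulsar-probability column and recording the true-positive and false-positive rates at each step. Here’s a minimal numpy sketch with toy scores. The helpers `roc_points` and `trapezoid_auc` are hypothetical and aren’t scikit-plot’s internals.

```python
import numpy as np

def roc_points(y_true, scores):
    """Hypothetical helper: false/true positive rates at every distinct score threshold."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(-scores, kind="mergesort")    # highest score first
    y_sorted, s_sorted = y_true[order], scores[order]
    distinct = np.where(np.diff(s_sorted))[0]         # last index before each change in score
    idx = np.r_[distinct, y_sorted.size - 1]
    tps = np.cumsum(y_sorted)[idx]                     # positives above each threshold
    fps = (idx + 1) - tps                              # negatives above each threshold
    return np.r_[0, fps / fps[-1]], np.r_[0, tps / tps[-1]]

def trapezoid_auc(fpr, tpr):
    """Area under the ROC curve by the trapezoid rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))

# Toy labels and pulsar probabilities (assumption); with the real data, assuming 0/1 labels,
# you would pass y_test and probas[:, 1] instead
fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
auc = trapezoid_auc(fpr, tpr)
print(fpr, tpr, auc)
```

The AUC is the area under that staircase. Equivalently, it’s the probability that a randomly chosen pulsar gets a higher score than a randomly chosen non-pulsar (3 of the 4 pairs here).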

In a balanced dataset there should be little difference between the micro-average ROC curve and the macro-average ROC curve. Where there is a class imbalance (like here), a macro-average curve that sits below the micro-average curve suggests that there are more misclassifications in the minority class.


We can also use the output of the RFC.predict_proba() function to plot a precision-recall curve:

skplt.metrics.plot_precision_recall(y_test, probas)
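A precision-recall curve comes from the same threshold sweep. As the threshold on the pulsar probability drops, recall can only rise, while precision usually (but not always) falls. For a rare class like pulsars this view is often more informative than the ROC curve, because it ignores the large number of easy true negatives. Here’s a toy sketch using the same made-up labels and scores as an illustrative assumption:

```python
import numpy as np

y_true = np.array([0, 0, 1, 1])              # toy labels (assumption)
scores = np.array([0.1, 0.4, 0.35, 0.8])     # toy pulsar probabilities (assumption)

curve = []
for threshold in sorted(set(scores.tolist()), reverse=True):
    predicted = scores >= threshold          # call a candidate a pulsar at or above the threshold
    tp = int(np.sum(predicted & (y_true == 1)))
    fp = int(np.sum(predicted & (y_true == 0)))
    precision = tp / (tp + fp)
    recall = tp / int(np.sum(y_true == 1))
    curve.append((threshold, precision, recall))
    print(f"threshold={threshold:.2f} precision={precision:.2f} recall={recall:.2f}")
```

Note the dip and recovery in precision between thresholds 0.4 and 0.35, which is why precision-recall curves tend to look jagged on small samples.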


Ranking the Features

Let’s take a look at the relative importance of the different features that we fed to our classifier:

importances = RFC.feature_importances_
indices = np.argsort(importances)
pl.title('Feature Importances')
pl.barh(range(len(indices)), importances[indices], color='b', align='center')
pl.yticks(range(len(indices)), feature_names[indices])
pl.xlabel('Relative Importance')
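These importances are impurity-based. With criterion='gini' (the default shown in the fitted model above), each split reduces the Gini impurity 1 − Σ p_k² of the samples reaching it. A feature’s importance is the total weighted reduction from all the splits that use it, averaged over the trees and normalised so that all the importances sum to one. Here’s a minimal pure-Python sketch of the impurity decrease for one split. The counts and the helper names `gini` and `impurity_decrease` are made up for illustration.

```python
def gini(counts):
    """Gini impurity 1 - sum(p_k^2) for a list of class counts."""
    n = sum(counts)
    return 1.0 - sum((c / n) ** 2 for c in counts)

def impurity_decrease(parent, left, right):
    """Weighted drop in Gini impurity when `parent` is split into `left` and `right`."""
    n = sum(parent)
    return (gini(parent)
            - sum(left) / n * gini(left)
            - sum(right) / n * gini(right))

# Hypothetical split; each list is [non-pulsars, pulsars] in a node
parent = [90, 10]
left, right = [85, 2], [5, 8]
print(round(gini(parent), 3), round(impurity_decrease(parent, left, right), 3))
```

In a real tree each decrease is also weighted by the fraction of all training samples that reach that node. Splits near the root, which see more samples, therefore count for more.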



